@@ -379,9 +379,10 @@ def prepare_inputs_for_generation(
         # If we have cache: let's slice `input_ids` through `cache_position`, to keep only the unprocessed tokens
         # Exception 1: when passing input_embeds, input_ids may be missing entries
         # Exception 2: some generation methods do special slicing of input_ids, so we don't need to do it here
+        # Exception 3: with synced GPUs cache_position may go out of bounds, but we only want dummy token in that case
         if past_key_values is not None:
             model_inputs["past_key_values"] = past_key_values
-            if inputs_embeds is not None:  # Exception 1
+            if inputs_embeds is not None or cache_position[-1] >= input_ids.shape[1]:  # Exception 1 or Exception 3
                 input_ids = input_ids[:, -cache_position.shape[0] :]
             elif input_ids.shape[1] != cache_position.shape[0]:  # Default case (the "else", a no op, is Exception 2)
                 input_ids = input_ids[:, cache_position]
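For context, a minimal sketch of the slicing this hunk changes (toy tensors only, not the real generation state): under `synced_gpus`, a peer that has already finished keeps doing dummy forward passes, so `cache_position` can advance past the end of its `input_ids`; the new Exception 3 routes that case into the same "keep only the last tokens" slice used for `inputs_embeds`.

```python
import torch

# A finished peer under synced_gpus: input_ids stops growing, cache_position keeps advancing.
input_ids = torch.tensor([[5, 8, 13]])   # 3 tokens seen so far
cache_position = torch.tensor([7])       # already past input_ids.shape[1]
inputs_embeds = None

if inputs_embeds is not None or cache_position[-1] >= input_ids.shape[1]:  # Exception 1 or 3
    # keep only the last len(cache_position) tokens -> a single dummy token here
    input_ids = input_ids[:, -cache_position.shape[0] :]
elif input_ids.shape[1] != cache_position.shape[0]:
    input_ids = input_ids[:, cache_position]  # would raise IndexError without Exception 3

print(input_ids)  # tensor([[13]])
```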
@@ -2609,8 +2610,14 @@ def _dola_decoding(
                     outputs.hidden_states[candidate_premature_layer][:, -1, :]
                 ).to(final_logits.device)

+            # synced_gpus: don't waste resources running the code we don't need; kwargs must be updated before skipping
+            model_kwargs = self._update_model_kwargs_for_generation(
+                outputs,
+                model_kwargs,
+                is_encoder_decoder=self.config.is_encoder_decoder,
+            )
             if synced_gpus and this_peer_finished:
-                continue  # don't waste resources running the code we don't need
+                continue

             next_token_logits = _dola_select_contrast(
                 candidate_premature_layers, candidate_premature_logits, final_logits
@@ -2652,11 +2659,6 @@ def _dola_decoding(
             input_ids = torch.cat([input_ids, next_tokens[:, None]], dim=-1)
             if streamer is not None:
                 streamer.put(next_tokens.cpu())
-            model_kwargs = self._update_model_kwargs_for_generation(
-                outputs,
-                model_kwargs,
-                is_encoder_decoder=self.config.is_encoder_decoder,
-            )

             # stop when each sentence is finished
             unfinished_sequences = unfinished_sequences & ~stopping_criteria(input_ids, scores)
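The same reordering is applied to every decoding loop below (`_contrastive_search`, `_sample`, `_beam_search`, `_group_beam_search`, `_constrained_beam_search`, `_assisted_decoding`): `_update_model_kwargs_for_generation` now runs before the `synced_gpus` early `continue`, so the per-step bookkeeping it maintains (e.g. `cache_position`, the attention mask) keeps advancing on peers that are already finished but still have to execute dummy forward passes. A self-contained toy sketch of why the placement matters (plain Python, no transformers objects; `cache_position` here merely stands in for that bookkeeping):

```python
# Toy model of one peer across 5 synced iterations. With the old placement, the kwargs
# update was skipped once the peer finished, so its state went stale; with the new
# placement it keeps advancing in lockstep with the unfinished peers.
def run(peer: dict, update_before_skip: bool) -> dict:
    for _ in range(5):
        if update_before_skip:
            peer["cache_position"] += 1  # mirrors moving the kwargs update above the skip
        if peer["finished"]:
            continue  # dummy iteration: no token selection
        if not update_before_skip:
            peer["cache_position"] += 1  # old placement: unreachable once finished
    return peer

print(run({"finished": True, "cache_position": 3}, update_before_skip=False))  # stays at 3 (stale)
print(run({"finished": True, "cache_position": 3}, update_before_skip=True))   # advances to 8
```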
@@ -3016,8 +3018,14 @@ def _contrastive_search(
             )
             # contrastive_search main logic end

+            # synced_gpus: don't waste resources running the code we don't need; kwargs must be updated before skipping
+            model_kwargs = self._update_model_kwargs_for_generation(
+                outputs,
+                model_kwargs,
+                is_encoder_decoder=self.config.is_encoder_decoder,
+            )
             if synced_gpus and this_peer_finished:
-                continue  # don't waste resources running the code we don't need
+                continue

             # finished sentences should have their next token be a padding token
             if has_eos_stopping_criteria:
@@ -3027,11 +3035,6 @@ def _contrastive_search(
             input_ids = torch.cat([input_ids, next_tokens[:, None]], dim=-1)
             if streamer is not None:
                 streamer.put(next_tokens.cpu())
-            model_kwargs = self._update_model_kwargs_for_generation(
-                outputs,
-                model_kwargs,
-                is_encoder_decoder=self.config.is_encoder_decoder,
-            )

             # stop when each sentence is finished
             unfinished_sequences = unfinished_sequences & ~stopping_criteria(input_ids, scores)
@@ -3168,8 +3171,14 @@ def _sample(
             # forward pass to get next token
             outputs = self(**model_inputs, return_dict=True)

+            # synced_gpus: don't waste resources running the code we don't need; kwargs must be updated before skipping
+            model_kwargs = self._update_model_kwargs_for_generation(
+                outputs,
+                model_kwargs,
+                is_encoder_decoder=self.config.is_encoder_decoder,
+            )
             if synced_gpus and this_peer_finished:
-                continue  # don't waste resources running the code we don't need
+                continue

             # Clone is needed to avoid keeping a hanging ref to outputs.logits which may be very large for first iteration
             # (the clone itself is always small)
@@ -3214,11 +3223,6 @@ def _sample(
             input_ids = torch.cat([input_ids, next_tokens[:, None]], dim=-1)
             if streamer is not None:
                 streamer.put(next_tokens.cpu())
-            model_kwargs = self._update_model_kwargs_for_generation(
-                outputs,
-                model_kwargs,
-                is_encoder_decoder=self.config.is_encoder_decoder,
-            )

             unfinished_sequences = unfinished_sequences & ~stopping_criteria(input_ids, scores)
             this_peer_finished = unfinished_sequences.max() == 0
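For reference, a hedged usage sketch of the code path `_sample` exercises here: `synced_gpus` matters when the model is sharded (for example under DeepSpeed ZeRO-3 or FSDP) and every rank must keep calling forward until all ranks have finished. The checkpoint name and prompt below are placeholders, not part of this PR; the sketch runs single-process only to show where the flag plugs in.

```python
from transformers import AutoModelForCausalLM, AutoTokenizer

tokenizer = AutoTokenizer.from_pretrained("gpt2")
model = AutoModelForCausalLM.from_pretrained("gpt2")

inputs = tokenizer(["Hello there"], return_tensors="pt")
output_ids = model.generate(
    **inputs,
    max_new_tokens=8,
    synced_gpus=True,  # finished ranks keep doing dummy forward passes instead of returning early
)
print(tokenizer.batch_decode(output_ids, skip_special_tokens=True))
```

Inside the loops, `this_peer_finished` marks a rank whose sequences are all done; with `synced_gpus=True`, `_has_unfinished_sequences` keeps it looping until every rank is finished.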
@@ -3415,9 +3419,15 @@ def _beam_search(
             else:  # Unchanged original behavior
                 outputs = self(**model_inputs, return_dict=True)

+            # synced_gpus: don't waste resources running the code we don't need; kwargs must be updated before skipping
+            model_kwargs = self._update_model_kwargs_for_generation(
+                outputs,
+                model_kwargs,
+                is_encoder_decoder=self.config.is_encoder_decoder,
+            )
             if synced_gpus and this_peer_finished:
                 cur_len = cur_len + 1
-                continue  # don't waste resources running the code we don't need
+                continue

             # Clone is needed to avoid keeping a hanging ref to outputs.logits which may be very large for first iteration
             # (the clone itself is always small)
@@ -3491,12 +3501,6 @@ def _beam_search(

             input_ids = torch.cat([input_ids[beam_idx, :], beam_next_tokens.unsqueeze(-1)], dim=-1)

-            model_kwargs = self._update_model_kwargs_for_generation(
-                outputs,
-                model_kwargs,
-                is_encoder_decoder=self.config.is_encoder_decoder,
-            )
-
             # This is needed to properly delete outputs.logits which may be very large for first iteration
             # Otherwise a reference to outputs is kept which keeps the logits alive in the next iteration
             # IMPORTANT: Note that this should appear BEFORE the call to _reorder_cache() to save the maximum memory
@@ -3670,9 +3674,15 @@ def _group_beam_search(

             outputs = self(**model_inputs, return_dict=True)

+            # synced_gpus: don't waste resources running the code we don't need; kwargs must be updated before skipping
+            model_kwargs = self._update_model_kwargs_for_generation(
+                outputs,
+                model_kwargs,
+                is_encoder_decoder=self.config.is_encoder_decoder,
+            )
             if synced_gpus and this_peer_finished:
                 cur_len = cur_len + 1
-                continue  # don't waste resources running the code we don't need
+                continue

             if output_scores:
                 processed_score = torch.zeros_like(outputs.logits[:, -1, :])
@@ -3782,12 +3792,6 @@ def _group_beam_search(

             input_ids = torch.cat([input_ids, current_tokens.unsqueeze(-1)], dim=-1)

-            model_kwargs = self._update_model_kwargs_for_generation(
-                outputs,
-                model_kwargs,
-                is_encoder_decoder=self.config.is_encoder_decoder,
-            )
-
             # This is needed to properly delete outputs.logits which may be very large for first iteration
             # Otherwise a reference to outputs is kept which keeps the logits alive in the next iteration
             # IMPORTANT: Note that this should appear BEFORE the call to _reorder_cache() to save the maximum memory
@@ -3948,9 +3952,15 @@ def _constrained_beam_search(

             outputs = self(**model_inputs, return_dict=True)

+            # synced_gpus: don't waste resources running the code we don't need; kwargs must be updated before skipping
+            model_kwargs = self._update_model_kwargs_for_generation(
+                outputs,
+                model_kwargs,
+                is_encoder_decoder=self.config.is_encoder_decoder,
+            )
             if synced_gpus and this_peer_finished:
                 cur_len = cur_len + 1
-                continue  # don't waste resources running the code we don't need
+                continue

             # Clone is needed to avoid keeping a hanging ref to outputs.logits which may be very large for first iteration
             # (the clone itself is always small)
@@ -4018,11 +4028,6 @@ def _constrained_beam_search(
             beam_idx = beam_outputs["next_beam_indices"]

             input_ids = torch.cat([input_ids[beam_idx, :], beam_next_tokens.unsqueeze(-1)], dim=-1)
-            model_kwargs = self._update_model_kwargs_for_generation(
-                outputs,
-                model_kwargs,
-                is_encoder_decoder=self.config.is_encoder_decoder,
-            )

             # This is needed to properly delete outputs.logits which may be very large for first iteration
             # Otherwise a reference to outputs is kept which keeps the logits alive in the next iteration
@@ -4162,17 +4167,8 @@ def _assisted_decoding(
         unfinished_sequences = torch.ones(batch_size, dtype=torch.long, device=input_ids.device)
         model_kwargs = self._get_initial_cache_position(input_ids, model_kwargs)

-        # This is needed if return_dict_in_generate is True
-        start_from_empty_dynamic_cache = False
-        past_key_values = model_kwargs.get("past_key_values", None)
-        if isinstance(past_key_values, DynamicCache) or (
-            isinstance(past_key_values, EncoderDecoderCache)
-            and isinstance(past_key_values.self_attention_cache, DynamicCache)
-        ):
-            if past_key_values.get_seq_length() == 0:
-                start_from_empty_dynamic_cache = True
-
         this_peer_finished = False
+        is_first_iteration = True  # to preserve the same API in the output as other generation methods
         while self._has_unfinished_sequences(this_peer_finished, synced_gpus, device=input_ids.device):
             cur_len = input_ids.shape[-1]

@@ -4271,63 +4267,59 @@ def _assisted_decoding(
             # 5. Update the candidate generation strategy if needed
             candidate_generator.update_candidate_strategy(input_ids, new_logits, n_matches)

+            # synced_gpus: don't waste resources running the code we don't need; kwargs must be updated before skipping
+            model_kwargs = self._update_model_kwargs_for_generation(
+                outputs,
+                model_kwargs,
+                is_encoder_decoder=self.config.is_encoder_decoder,
+                num_new_tokens=n_matches + 1,
+            )
             if synced_gpus and this_peer_finished:
-                continue  # don't waste resources running the code we don't need
+                continue

             # Store scores, attentions and hidden_states when required
             # Assistant: modified to append one tuple element per token, as in the other generation methods.
             if return_dict_in_generate:
+                newly_added_length = n_matches + 1
                 if output_scores:
-                    scores += tuple(new_logits[:, i, :] for i in range(n_matches + 1))
+                    scores += tuple(new_logits[:, i, :] for i in range(newly_added_length))
                 if output_logits:
-                    raw_logits += (next_token_logits,)
-
-                if "past_key_values" not in model_kwargs or start_from_empty_dynamic_cache:
-                    added_len = new_cur_len
-                    # set it to false for other iterations
-                    start_from_empty_dynamic_cache = False
-                else:
-                    added_len = n_matches + 1
+                    raw_logits += tuple(next_token_logits[:, i, :] for i in range(newly_added_length))

+                newly_added_length = new_cur_len if is_first_iteration else newly_added_length
                 if output_attentions:
                     if self.config.is_encoder_decoder:
                         cross_attentions = _split_model_outputs(
-                            cross_attentions, outputs.cross_attentions, cur_len, added_len
+                            cross_attentions, outputs.cross_attentions, cur_len, newly_added_length
                         )
                         decoder_attentions = _split_model_outputs(
                             decoder_attentions,
                             outputs.decoder_attentions,
                             cur_len,
-                            added_len,
+                            newly_added_length,
                             is_decoder_attention=True,
                         )
                     else:
                         decoder_attentions = _split_model_outputs(
                             decoder_attentions,
                             outputs.attentions,
                             cur_len,
-                            added_len,
+                            newly_added_length,
                             is_decoder_attention=True,
                         )
                 if output_hidden_states:
                     if self.config.is_encoder_decoder:
                         decoder_hidden_states = _split_model_outputs(
-                            decoder_hidden_states, outputs.decoder_hidden_states, cur_len, added_len
+                            decoder_hidden_states, outputs.decoder_hidden_states, cur_len, newly_added_length
                         )
                     else:
                         decoder_hidden_states = _split_model_outputs(
-                            decoder_hidden_states, outputs.hidden_states, cur_len, added_len
+                            decoder_hidden_states, outputs.hidden_states, cur_len, newly_added_length
                        )

-            model_kwargs = self._update_model_kwargs_for_generation(
-                outputs,
-                model_kwargs,
-                is_encoder_decoder=self.config.is_encoder_decoder,
-                num_new_tokens=n_matches + 1,
-            )
-
             unfinished_sequences = unfinished_sequences & ~stopping_criteria(input_ids, scores)
             this_peer_finished = unfinished_sequences.max() == 0
+            is_first_iteration = False

         if streamer is not None:
             streamer.end()
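A worked example of the assisted-decoding bookkeeping after this change (made-up numbers; `prompt_len`, `n_matches`, and `new_cur_len` mirror the variables in the hunk): `scores` and `raw_logits` now receive one entry per newly generated token, while the attention/hidden-state splits use `newly_added_length`, which on the first iteration spans the whole prompt as well. The new `is_first_iteration` flag is a simpler way to detect that first pass than the old check on an empty `DynamicCache`.

```python
# Hypothetical two iterations of assisted decoding (numbers are illustrative only).
prompt_len = 10

# Iteration 1: say 3 candidate tokens are accepted (n_matches = 3)
n_matches, is_first_iteration = 3, True
new_cur_len = prompt_len + n_matches + 1               # 14 tokens total after this step
newly_added_length = n_matches + 1                      # 4 entries appended to scores/raw_logits
newly_added_length = new_cur_len if is_first_iteration else newly_added_length
print(newly_added_length)  # 14: attentions/hidden states cover the prompt too on the first pass

# Iteration 2: say only 1 candidate token is accepted (n_matches = 1)
n_matches, is_first_iteration = 1, False
newly_added_length = n_matches + 1
print(newly_added_length)  # 2: only the newly added tokens are split out
```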