Commit 37ea040

Generate: Fix modern llm generate calls with synced_gpus (#34095)
1 parent 617b212 commit 37ea040
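
With `synced_gpus=True`, peers that have already finished generating keep running the forward pass so that distributed collective ops stay aligned, but the per-step `continue` used to skip the `_update_model_kwargs_for_generation` call. As the new in-diff comments put it, the kwargs must be updated before skipping: otherwise values such as `cache_position` stop advancing on finished peers while the cache itself keeps growing, which breaks modern cache-position-based `generate` calls. The commit therefore moves that update in front of the `synced_gpus and this_peer_finished` skip in every decoding method, and teaches `prepare_inputs_for_generation` to treat an out-of-bounds `cache_position` as a dummy-token step (Exception 3 in the first hunk). Below is a minimal, self-contained sketch of the intended loop ordering; the names (`decoding_loop_sketch`, the fake `outputs` dict) are hypothetical stand-ins, not the actual transformers code.

# Minimal sketch (hypothetical names): update the kwargs BEFORE the synced_gpus skip,
# so dummy steps on finished peers keep advancing cache_position like the active peers.
def decoding_loop_sketch(max_steps: int, finish_after: int, synced_gpus: bool = True):
    model_kwargs = {"cache_position": 0}  # stand-in for the real model_kwargs dict
    generated = []
    this_peer_finished = False

    for step in range(max_steps):
        # every peer runs the forward pass so distributed collectives stay aligned
        outputs = {"next_token": step}  # fake forward result

        # new ordering: kwargs are updated before skipping (previously this update ran
        # after sampling and was therefore skipped by the `continue` below)
        model_kwargs["cache_position"] += 1
        if synced_gpus and this_peer_finished:
            continue  # don't waste resources running the code we don't need

        generated.append(outputs["next_token"])
        if len(generated) >= finish_after:
            this_peer_finished = True

    return generated, model_kwargs


print(decoding_loop_sketch(max_steps=6, finish_after=3))
# ([0, 1, 2], {'cache_position': 6}) -> sampling stops after 3 tokens, but
# cache_position still advances on every step, staying in sync with the other peers.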

1 file changed

src/transformers/generation/utils.py: 63 additions, 71 deletions

@@ -379,9 +379,10 @@ def prepare_inputs_for_generation(
         # If we have cache: let's slice `input_ids` through `cache_position`, to keep only the unprocessed tokens
         # Exception 1: when passing input_embeds, input_ids may be missing entries
         # Exception 2: some generation methods do special slicing of input_ids, so we don't need to do it here
+        # Exception 3: with synced GPUs cache_position may go out of bounds, but we only want dummy token in that case
         if past_key_values is not None:
             model_inputs["past_key_values"] = past_key_values
-            if inputs_embeds is not None:  # Exception 1
+            if inputs_embeds is not None or cache_position[-1] >= input_ids.shape[1]:  # Exception 1 or Exception 3
                 input_ids = input_ids[:, -cache_position.shape[0] :]
             elif input_ids.shape[1] != cache_position.shape[0]:  # Default case (the "else", a no op, is Exception 2)
                 input_ids = input_ids[:, cache_position]
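
To see why Exception 3 is needed: on a finished peer under `synced_gpus`, `input_ids` stops growing while `cache_position` keeps stepping forward (see the sketch above), so its last index can point past the end of `input_ids`. The hunk then reuses the `inputs_embeds` slicing so the finished peer only feeds a dummy trailing token. A small illustration of that slice follows; it assumes PyTorch is installed and the tensor values are made up.

import torch

# Made-up state for a finished peer: input_ids stopped growing at 4 tokens,
# but cache_position has already advanced to index 5.
input_ids = torch.tensor([[5, 6, 7, 8]])
cache_position = torch.tensor([5])

if cache_position[-1] >= input_ids.shape[1]:  # Exception 3 branch from the hunk above
    input_ids = input_ids[:, -cache_position.shape[0] :]  # keep one dummy trailing token

print(input_ids)  # tensor([[8]]) -> a single placeholder token for the dummy forward pass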
@@ -2609,8 +2610,14 @@ def _dola_decoding(
                     outputs.hidden_states[candidate_premature_layer][:, -1, :]
                 ).to(final_logits.device)

+            # synced_gpus: don't waste resources running the code we don't need; kwargs must be updated before skipping
+            model_kwargs = self._update_model_kwargs_for_generation(
+                outputs,
+                model_kwargs,
+                is_encoder_decoder=self.config.is_encoder_decoder,
+            )
             if synced_gpus and this_peer_finished:
-                continue  # don't waste resources running the code we don't need
+                continue

             next_token_logits = _dola_select_contrast(
                 candidate_premature_layers, candidate_premature_logits, final_logits
@@ -2652,11 +2659,6 @@ def _dola_decoding(
             input_ids = torch.cat([input_ids, next_tokens[:, None]], dim=-1)
             if streamer is not None:
                 streamer.put(next_tokens.cpu())
-            model_kwargs = self._update_model_kwargs_for_generation(
-                outputs,
-                model_kwargs,
-                is_encoder_decoder=self.config.is_encoder_decoder,
-            )

             # stop when each sentence is finished
             unfinished_sequences = unfinished_sequences & ~stopping_criteria(input_ids, scores)
@@ -3016,8 +3018,14 @@ def _contrastive_search(
                 )
             # contrastive_search main logic end

+            # synced_gpus: don't waste resources running the code we don't need; kwargs must be updated before skipping
+            model_kwargs = self._update_model_kwargs_for_generation(
+                outputs,
+                model_kwargs,
+                is_encoder_decoder=self.config.is_encoder_decoder,
+            )
             if synced_gpus and this_peer_finished:
-                continue  # don't waste resources running the code we don't need
+                continue

             # finished sentences should have their next token be a padding token
             if has_eos_stopping_criteria:
@@ -3027,11 +3035,6 @@ def _contrastive_search(
             input_ids = torch.cat([input_ids, next_tokens[:, None]], dim=-1)
             if streamer is not None:
                 streamer.put(next_tokens.cpu())
-            model_kwargs = self._update_model_kwargs_for_generation(
-                outputs,
-                model_kwargs,
-                is_encoder_decoder=self.config.is_encoder_decoder,
-            )

             # stop when each sentence is finished
             unfinished_sequences = unfinished_sequences & ~stopping_criteria(input_ids, scores)
@@ -3168,8 +3171,14 @@ def _sample(
             # forward pass to get next token
             outputs = self(**model_inputs, return_dict=True)

+            # synced_gpus: don't waste resources running the code we don't need; kwargs must be updated before skipping
+            model_kwargs = self._update_model_kwargs_for_generation(
+                outputs,
+                model_kwargs,
+                is_encoder_decoder=self.config.is_encoder_decoder,
+            )
             if synced_gpus and this_peer_finished:
-                continue  # don't waste resources running the code we don't need
+                continue

             # Clone is needed to avoid keeping a hanging ref to outputs.logits which may be very large for first iteration
             # (the clone itself is always small)
@@ -3214,11 +3223,6 @@ def _sample(
             input_ids = torch.cat([input_ids, next_tokens[:, None]], dim=-1)
             if streamer is not None:
                 streamer.put(next_tokens.cpu())
-            model_kwargs = self._update_model_kwargs_for_generation(
-                outputs,
-                model_kwargs,
-                is_encoder_decoder=self.config.is_encoder_decoder,
-            )

             unfinished_sequences = unfinished_sequences & ~stopping_criteria(input_ids, scores)
             this_peer_finished = unfinished_sequences.max() == 0
@@ -3415,9 +3419,15 @@ def _beam_search(
             else:  # Unchanged original behavior
                 outputs = self(**model_inputs, return_dict=True)

+            # synced_gpus: don't waste resources running the code we don't need; kwargs must be updated before skipping
+            model_kwargs = self._update_model_kwargs_for_generation(
+                outputs,
+                model_kwargs,
+                is_encoder_decoder=self.config.is_encoder_decoder,
+            )
             if synced_gpus and this_peer_finished:
                 cur_len = cur_len + 1
-                continue  # don't waste resources running the code we don't need
+                continue

             # Clone is needed to avoid keeping a hanging ref to outputs.logits which may be very large for first iteration
             # (the clone itself is always small)
@@ -3491,12 +3501,6 @@ def _beam_search(

             input_ids = torch.cat([input_ids[beam_idx, :], beam_next_tokens.unsqueeze(-1)], dim=-1)

-            model_kwargs = self._update_model_kwargs_for_generation(
-                outputs,
-                model_kwargs,
-                is_encoder_decoder=self.config.is_encoder_decoder,
-            )
-
             # This is needed to properly delete outputs.logits which may be very large for first iteration
             # Otherwise a reference to outputs is kept which keeps the logits alive in the next iteration
             # IMPORTANT: Note that this should appear BEFORE the call to _reorder_cache() to save the maximum memory
@@ -3670,9 +3674,15 @@ def _group_beam_search(

             outputs = self(**model_inputs, return_dict=True)

+            # synced_gpus: don't waste resources running the code we don't need; kwargs must be updated before skipping
+            model_kwargs = self._update_model_kwargs_for_generation(
+                outputs,
+                model_kwargs,
+                is_encoder_decoder=self.config.is_encoder_decoder,
+            )
             if synced_gpus and this_peer_finished:
                 cur_len = cur_len + 1
-                continue  # don't waste resources running the code we don't need
+                continue

             if output_scores:
                 processed_score = torch.zeros_like(outputs.logits[:, -1, :])
@@ -3782,12 +3792,6 @@ def _group_beam_search(

             input_ids = torch.cat([input_ids, current_tokens.unsqueeze(-1)], dim=-1)

-            model_kwargs = self._update_model_kwargs_for_generation(
-                outputs,
-                model_kwargs,
-                is_encoder_decoder=self.config.is_encoder_decoder,
-            )
-
             # This is needed to properly delete outputs.logits which may be very large for first iteration
             # Otherwise a reference to outputs is kept which keeps the logits alive in the next iteration
             # IMPORTANT: Note that this should appear BEFORE the call to _reorder_cache() to save the maximum memory
@@ -3948,9 +3952,15 @@ def _constrained_beam_search(

             outputs = self(**model_inputs, return_dict=True)

+            # synced_gpus: don't waste resources running the code we don't need; kwargs must be updated before skipping
+            model_kwargs = self._update_model_kwargs_for_generation(
+                outputs,
+                model_kwargs,
+                is_encoder_decoder=self.config.is_encoder_decoder,
+            )
             if synced_gpus and this_peer_finished:
                 cur_len = cur_len + 1
-                continue  # don't waste resources running the code we don't need
+                continue

             # Clone is needed to avoid keeping a hanging ref to outputs.logits which may be very large for first iteration
             # (the clone itself is always small)
@@ -4018,11 +4028,6 @@ def _constrained_beam_search(
             beam_idx = beam_outputs["next_beam_indices"]

             input_ids = torch.cat([input_ids[beam_idx, :], beam_next_tokens.unsqueeze(-1)], dim=-1)
-            model_kwargs = self._update_model_kwargs_for_generation(
-                outputs,
-                model_kwargs,
-                is_encoder_decoder=self.config.is_encoder_decoder,
-            )

             # This is needed to properly delete outputs.logits which may be very large for first iteration
             # Otherwise a reference to outputs is kept which keeps the logits alive in the next iteration
@@ -4162,17 +4167,8 @@ def _assisted_decoding(
         unfinished_sequences = torch.ones(batch_size, dtype=torch.long, device=input_ids.device)
         model_kwargs = self._get_initial_cache_position(input_ids, model_kwargs)

-        # This is needed if return_dict_in_generate is True
-        start_from_empty_dynamic_cache = False
-        past_key_values = model_kwargs.get("past_key_values", None)
-        if isinstance(past_key_values, DynamicCache) or (
-            isinstance(past_key_values, EncoderDecoderCache)
-            and isinstance(past_key_values.self_attention_cache, DynamicCache)
-        ):
-            if past_key_values.get_seq_length() == 0:
-                start_from_empty_dynamic_cache = True
-
         this_peer_finished = False
+        is_first_iteration = True  # to preserve the same API in the output as other generation methods
         while self._has_unfinished_sequences(this_peer_finished, synced_gpus, device=input_ids.device):
             cur_len = input_ids.shape[-1]

@@ -4271,63 +4267,59 @@ def _assisted_decoding(
             # 5. Update the candidate generation strategy if needed
             candidate_generator.update_candidate_strategy(input_ids, new_logits, n_matches)

+            # synced_gpus: don't waste resources running the code we don't need; kwargs must be updated before skipping
+            model_kwargs = self._update_model_kwargs_for_generation(
+                outputs,
+                model_kwargs,
+                is_encoder_decoder=self.config.is_encoder_decoder,
+                num_new_tokens=n_matches + 1,
+            )
             if synced_gpus and this_peer_finished:
-                continue  # don't waste resources running the code we don't need
+                continue

             # Store scores, attentions and hidden_states when required
             # Assistant: modified to append one tuple element per token, as in the other generation methods.
             if return_dict_in_generate:
+                newly_added_length = n_matches + 1
                 if output_scores:
-                    scores += tuple(new_logits[:, i, :] for i in range(n_matches + 1))
+                    scores += tuple(new_logits[:, i, :] for i in range(newly_added_length))
                 if output_logits:
-                    raw_logits += (next_token_logits,)
-
-                if "past_key_values" not in model_kwargs or start_from_empty_dynamic_cache:
-                    added_len = new_cur_len
-                    # set it to false for other iterations
-                    start_from_empty_dynamic_cache = False
-                else:
-                    added_len = n_matches + 1
+                    raw_logits += tuple(next_token_logits[:, i, :] for i in range(newly_added_length))

+                newly_added_length = new_cur_len if is_first_iteration else newly_added_length
                 if output_attentions:
                     if self.config.is_encoder_decoder:
                         cross_attentions = _split_model_outputs(
-                            cross_attentions, outputs.cross_attentions, cur_len, added_len
+                            cross_attentions, outputs.cross_attentions, cur_len, newly_added_length
                         )
                         decoder_attentions = _split_model_outputs(
                             decoder_attentions,
                             outputs.decoder_attentions,
                             cur_len,
-                            added_len,
+                            newly_added_length,
                             is_decoder_attention=True,
                         )
                     else:
                         decoder_attentions = _split_model_outputs(
                             decoder_attentions,
                             outputs.attentions,
                             cur_len,
-                            added_len,
+                            newly_added_length,
                             is_decoder_attention=True,
                         )
                 if output_hidden_states:
                     if self.config.is_encoder_decoder:
                         decoder_hidden_states = _split_model_outputs(
-                            decoder_hidden_states, outputs.decoder_hidden_states, cur_len, added_len
+                            decoder_hidden_states, outputs.decoder_hidden_states, cur_len, newly_added_length
                         )
                     else:
                         decoder_hidden_states = _split_model_outputs(
-                            decoder_hidden_states, outputs.hidden_states, cur_len, added_len
+                            decoder_hidden_states, outputs.hidden_states, cur_len, newly_added_length
                         )

-            model_kwargs = self._update_model_kwargs_for_generation(
-                outputs,
-                model_kwargs,
-                is_encoder_decoder=self.config.is_encoder_decoder,
-                num_new_tokens=n_matches + 1,
-            )
-
             unfinished_sequences = unfinished_sequences & ~stopping_criteria(input_ids, scores)
             this_peer_finished = unfinished_sequences.max() == 0
+            is_first_iteration = False

         if streamer is not None:
             streamer.end()
