
Commit b997a31

kashif, albertvillanova, and qgallouedec authored
[Online-DPO] fix the completion_len == max_new_tokens crash (#4193)
Co-authored-by: Albert Villanova del Moral <[email protected]>
Co-authored-by: Quentin Gallouédec <[email protected]>
1 parent 86d1963 commit b997a31

2 files changed: 35 additions & 24 deletions

trl/trainer/online_dpo_config.py

Lines changed: 7 additions & 0 deletions
@@ -412,3 +412,10 @@ def __post_init__(self):
 
         if hasattr(self.beta, "__len__") and len(self.beta) == 1:
             self.beta = self.beta[0]
+
+        if self.max_new_tokens >= self.max_length:
+            warnings.warn(
+                f"The configuration has `max_new_tokens` ({self.max_new_tokens}) >= `max_length` ({self.max_length}). "
+                "This will cause prompts to be truncated or completely removed in the forward pass. "
+                "To preserve prompts, ensure e.g. `max_length > max_new_tokens + 512`. ",
+            )
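
For context, a minimal sketch (not part of the commit) of the budget arithmetic behind the new warning: the prompt can only keep whatever room `max_length` leaves after reserving `max_new_tokens` for the completion, so an equal or larger `max_new_tokens` leaves no room for the prompt at all. `prompt_budget` is a hypothetical helper, not a TRL function.

# Illustration only; `prompt_budget` is hypothetical, not part of TRL.
def prompt_budget(max_length: int, max_new_tokens: int) -> int:
    # Tokens left for the prompt once the completion budget is reserved.
    return max_length - max_new_tokens

print(prompt_budget(1024, 256))  # 768 prompt tokens survive
print(prompt_budget(256, 256))   # 0 -> the prompt is removed entirely, which this warning flags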

trl/trainer/online_dpo_trainer.py

Lines changed: 28 additions & 24 deletions
@@ -57,8 +57,13 @@
 from ..extras.profiling import profiling_context
 from ..extras.vllm_client import VLLMClient
 from ..import_utils import is_vllm_available
-from ..models import create_reference_model, prepare_peft_model
-from ..models.utils import unwrap_model_for_generation
+from ..models import (
+    create_reference_model,
+    prepare_deepspeed,
+    prepare_fsdp,
+    prepare_peft_model,
+    unwrap_model_for_generation,
+)
 from .base_trainer import BaseTrainer
 from .judges import BasePairwiseJudge
 from .online_dpo_config import OnlineDPOConfig
@@ -69,7 +74,6 @@
     empty_cache,
     ensure_master_addr_port,
     pad,
-    prepare_deepspeed,
     truncate_right,
 )
 

@@ -588,24 +592,20 @@ def __init__(
         generation_kwargs = {k: v for k, v in generation_kwargs.items() if v is not None}
         self.generation_config = GenerationConfig(**generation_kwargs)
 
-        if self.is_deepspeed_enabled:
-            if self.ref_model is not None:
-                self.ref_model = prepare_deepspeed(
-                    self.ref_model, args.per_device_train_batch_size, args.fp16, args.bf16
-                )
-            # Prepare reward function models for DeepSpeed
-            if self.reward_funcs is not None:
-                for i, reward_func in enumerate(self.reward_funcs):
-                    if isinstance(reward_func, PreTrainedModel):
+        if self.ref_model is not None:
+            if self.is_deepspeed_enabled:
+                self.ref_model = prepare_deepspeed(self.ref_model, self.accelerator)
+            elif self.is_fsdp_enabled:
+                self.ref_model = prepare_fsdp(self.ref_model, self.accelerator)
+            else:
+                self.ref_model = self.accelerator.prepare_model(self.ref_model, evaluation_mode=True)
+        if self.reward_funcs is not None:
+            for i, reward_func in enumerate(self.reward_funcs):
+                if isinstance(reward_func, PreTrainedModel):
+                    if self.is_deepspeed_enabled:
                         self.reward_funcs[i] = prepare_deepspeed(reward_func, self.accelerator)
-        else:
-            if self.ref_model is not None:
-                self.ref_model = self.ref_model.to(self.accelerator.device)
-            # Prepare reward function models for FSDP/regular training
-            if self.reward_funcs is not None:
-                for i, reward_func in enumerate(self.reward_funcs):
-                    if isinstance(reward_func, PreTrainedModel):
-                        # Set device placement to True to make `prepare_model` move `reward_func` to device when using fsdp
+                    else:
+                        # set device placement to True to make `prepare_model` move `reward_func` to device when using fsdp
                         self.reward_funcs[i] = self.accelerator.prepare_model(
                             reward_func, evaluation_mode=True, device_placement=True
                         )
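
Read in isolation, the new branching amounts to a small dispatch on the distributed backend. The sketch below restates it outside the trainer, assuming the `prepare_deepspeed` and `prepare_fsdp` helpers imported at the top of the file keep the signatures shown in the diff; `prepare_eval_model` itself is a hypothetical name.

from trl.models import prepare_deepspeed, prepare_fsdp  # assumed import path, per the diff above

def prepare_eval_model(model, accelerator, is_deepspeed_enabled: bool, is_fsdp_enabled: bool):
    # Wrap an inference-only model (reference or reward model) for the active backend.
    if model is None:
        return None
    if is_deepspeed_enabled:
        return prepare_deepspeed(model, accelerator)
    if is_fsdp_enabled:
        return prepare_fsdp(model, accelerator)
    # Plain Accelerate: prepare in evaluation mode so no optimizer state is created.
    return accelerator.prepare_model(model, evaluation_mode=True)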
@@ -833,8 +833,10 @@ def _generate_vllm_server(self, prompts, images=None):
 
     def _generate_vllm_colocate(self, prompts, images=None):
         """Generate completions using vLLM colocate mode"""
-        # Update model weights if needed
-        self._move_model_to_vllm()
+        # Update model weights if needed - only after gradient accumulation completes
+        if self.state.global_step != self._last_loaded_step:
+            self._move_model_to_vllm()
+            self._last_loaded_step = self.state.global_step
 
         # Apply chat template if conversational
         if is_conversational({"prompt": prompts[0]}):
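
The gating added here follows a common pattern: push weights to the colocated vLLM engine at most once per optimizer step, so repeated generation calls within a gradient-accumulation window reuse the already-loaded weights. A stripped-down sketch, with `WeightSyncGate` and `push_weights` as hypothetical names rather than TRL code:

class WeightSyncGate:
    def __init__(self):
        self._last_loaded_step = -1  # sentinel: no weights pushed yet

    def maybe_sync(self, global_step: int, push_weights) -> bool:
        # Push only when the optimizer step has advanced since the last sync.
        if global_step != self._last_loaded_step:
            push_weights()
            self._last_loaded_step = global_step
            return True
        return False

gate = WeightSyncGate()
gate.maybe_sync(0, lambda: print("sync"))  # prints "sync"
gate.maybe_sync(0, lambda: print("sync"))  # skipped: same optimizer step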
@@ -1234,10 +1236,12 @@ def _forward(self, model, prompt_ids, prompt_mask, completion_ids, completion_ma
         # Get the logprobs of the completions from the model
         output = model(prompt_completion_ids, **model_kwargs)
 
-        # There is 1 offset, because the model predict the next token
+        # There is 1 offset, because the model predicts the next token
         prompt_len = prompt_ids.size(1)
         start_idx = prompt_len - 1 if prompt_len > 0 else 0
-        logits = output.logits[:, start_idx:-1]
+        # Only slice off the last logit when we have a prompt, otherwise we need all logits
+        end_idx = -1 if prompt_len > 0 else None
+        logits = output.logits[:, start_idx:end_idx]
 
         # Take the completion tokens logprob
         logprobs = torch.take_along_dim(logits.log_softmax(dim=-1), completion_ids.unsqueeze(-1), dim=2).squeeze(-1)
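
To see why the old slice crashed when the prompt is truncated away entirely (`prompt_len == 0`), here is a toy shape check using only PyTorch; the tensor sizes are made up for illustration.

import torch

batch, completion_len, vocab = 1, 4, 11
logits = torch.randn(batch, completion_len, vocab)              # prompt_len == 0: every position is a completion token
completion_ids = torch.randint(vocab, (batch, completion_len))

prompt_len = 0
start_idx = prompt_len - 1 if prompt_len > 0 else 0
end_idx = -1 if prompt_len > 0 else None                        # the fix: keep the final logit when there is no prompt
sliced = logits[:, start_idx:end_idx]
assert sliced.size(1) == completion_ids.size(1)                 # with the old `[:, 0:-1]` slice this is 3 vs 4 and take_along_dim fails

logprobs = torch.take_along_dim(sliced.log_softmax(dim=-1), completion_ids.unsqueeze(-1), dim=2).squeeze(-1)
print(logprobs.shape)  # torch.Size([1, 4])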
