|
57 | 57 | from ..extras.profiling import profiling_context |
58 | 58 | from ..extras.vllm_client import VLLMClient |
59 | 59 | from ..import_utils import is_vllm_available |
60 | | -from ..models import create_reference_model, prepare_peft_model |
61 | | -from ..models.utils import unwrap_model_for_generation |
| 60 | +from ..models import ( |
| 61 | + create_reference_model, |
| 62 | + prepare_deepspeed, |
| 63 | + prepare_fsdp, |
| 64 | + prepare_peft_model, |
| 65 | + unwrap_model_for_generation, |
| 66 | +) |
62 | 67 | from .base_trainer import BaseTrainer |
63 | 68 | from .judges import BasePairwiseJudge |
64 | 69 | from .online_dpo_config import OnlineDPOConfig |
|
69 | 74 | empty_cache, |
70 | 75 | ensure_master_addr_port, |
71 | 76 | pad, |
72 | | - prepare_deepspeed, |
73 | 77 | truncate_right, |
74 | 78 | ) |
75 | 79 |
|
@@ -588,24 +592,20 @@ def __init__( |
588 | 592 | generation_kwargs = {k: v for k, v in generation_kwargs.items() if v is not None} |
589 | 593 | self.generation_config = GenerationConfig(**generation_kwargs) |
590 | 594 |
|
591 | | - if self.is_deepspeed_enabled: |
592 | | - if self.ref_model is not None: |
593 | | - self.ref_model = prepare_deepspeed( |
594 | | - self.ref_model, args.per_device_train_batch_size, args.fp16, args.bf16 |
595 | | - ) |
596 | | - # Prepare reward function models for DeepSpeed |
597 | | - if self.reward_funcs is not None: |
598 | | - for i, reward_func in enumerate(self.reward_funcs): |
599 | | - if isinstance(reward_func, PreTrainedModel): |
| 595 | + if self.ref_model is not None: |
| 596 | + if self.is_deepspeed_enabled: |
| 597 | + self.ref_model = prepare_deepspeed(self.ref_model, self.accelerator) |
| 598 | + elif self.is_fsdp_enabled: |
| 599 | + self.ref_model = prepare_fsdp(self.ref_model, self.accelerator) |
| 600 | + else: |
| 601 | + self.ref_model = self.accelerator.prepare_model(self.ref_model, evaluation_mode=True) |
| 602 | + if self.reward_funcs is not None: |
| 603 | + for i, reward_func in enumerate(self.reward_funcs): |
| 604 | + if isinstance(reward_func, PreTrainedModel): |
| 605 | + if self.is_deepspeed_enabled: |
600 | 606 | self.reward_funcs[i] = prepare_deepspeed(reward_func, self.accelerator) |
601 | | - else: |
602 | | - if self.ref_model is not None: |
603 | | - self.ref_model = self.ref_model.to(self.accelerator.device) |
604 | | - # Prepare reward function models for FSDP/regular training |
605 | | - if self.reward_funcs is not None: |
606 | | - for i, reward_func in enumerate(self.reward_funcs): |
607 | | - if isinstance(reward_func, PreTrainedModel): |
608 | | - # Set device placement to True to make `prepare_model` move `reward_func` to device when using fsdp |
| 607 | + else: |
| 608 | + # set device placement to True to make `prepare_model` move `reward_func` to device when using fsdp |
609 | 609 | self.reward_funcs[i] = self.accelerator.prepare_model( |
610 | 610 | reward_func, evaluation_mode=True, device_placement=True |
611 | 611 | ) |
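
For readers following the refactor above: the frozen reference model (and any `PreTrainedModel` reward functions) is now prepared differently depending on the distributed backend. The sketch below restates that dispatch as a standalone helper; it assumes, as in this diff, that `prepare_deepspeed` and `prepare_fsdp` are exported from `trl.models` and take `(model, accelerator)`. The helper name `prepare_frozen_model` is purely illustrative.

```python
# Sketch of the frozen-model preparation dispatch shown above. The helper name is
# illustrative; prepare_deepspeed / prepare_fsdp are assumed to be exported from
# trl.models with the (model, accelerator) signature used in this diff.
from accelerate import Accelerator
from torch import nn


def prepare_frozen_model(
    model: nn.Module, accelerator: Accelerator, is_deepspeed: bool, is_fsdp: bool
) -> nn.Module:
    """Prepare a model that is only ever run for inference (ref model, reward model)."""
    if is_deepspeed:
        # Under DeepSpeed (e.g. ZeRO-3) even eval-only models need their own engine
        # so that sharded parameters can be gathered during the forward pass.
        from trl.models import prepare_deepspeed  # assumed export, per this diff

        return prepare_deepspeed(model, accelerator)
    if is_fsdp:
        # Same idea for FSDP: wrap the model so parameter gathering works.
        from trl.models import prepare_fsdp  # assumed export, per this diff

        return prepare_fsdp(model, accelerator)
    # Plain DDP / single device: let Accelerate place the model and keep it in eval mode.
    return accelerator.prepare_model(model, evaluation_mode=True)
```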
@@ -833,8 +833,10 @@ def _generate_vllm_server(self, prompts, images=None): |
833 | 833 |
|
834 | 834 | def _generate_vllm_colocate(self, prompts, images=None): |
835 | 835 | """Generate completions using vLLM colocate mode""" |
836 | | - # Update model weights if needed |
837 | | - self._move_model_to_vllm() |
| 836 | + # Update model weights if needed - only after gradient accumulation completes |
| 837 | + if self.state.global_step != self._last_loaded_step: |
| 838 | + self._move_model_to_vllm() |
| 839 | + self._last_loaded_step = self.state.global_step |
838 | 840 |
|
839 | 841 | # Apply chat template if conversational |
840 | 842 | if is_conversational({"prompt": prompts[0]}): |
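
The new check above throttles weight syncing in colocate mode: `_move_model_to_vllm()` now runs only when `state.global_step` has advanced (i.e. after an optimizer step), rather than on every generation call. A minimal sketch of the pattern, assuming `_last_loaded_step` is initialized (e.g. to -1) in `__init__`:

```python
# Sketch of the gating logic above; `trainer` stands for the trainer instance and
# `_last_loaded_step` is assumed to be initialized (e.g. to -1) in __init__.
def maybe_sync_weights_to_vllm(trainer) -> None:
    if trainer.state.global_step != trainer._last_loaded_step:
        trainer._move_model_to_vllm()  # expensive: streams updated policy weights into vLLM
        trainer._last_loaded_step = trainer.state.global_step
    # Otherwise reuse the weights already loaded: within one gradient-accumulation
    # window the policy has not changed, so re-syncing would be wasted work.
```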
@@ -1234,10 +1236,12 @@ def _forward(self, model, prompt_ids, prompt_mask, completion_ids, completion_ma |
1234 | 1236 | # Get the logprobs of the completions from the model |
1235 | 1237 | output = model(prompt_completion_ids, **model_kwargs) |
1236 | 1238 |
|
1237 | | - # There is 1 offset, because the model predict the next token |
| 1239 | + # There is 1 offset, because the model predicts the next token |
1238 | 1240 | prompt_len = prompt_ids.size(1) |
1239 | 1241 | start_idx = prompt_len - 1 if prompt_len > 0 else 0 |
1240 | | - logits = output.logits[:, start_idx:-1] |
| 1242 | + # Only slice off the last logit when we have a prompt, otherwise we need all logits |
| 1243 | + end_idx = -1 if prompt_len > 0 else None |
| 1244 | + logits = output.logits[:, start_idx:end_idx] |
1241 | 1245 |
|
1242 | 1246 | # Take the completion tokens logprob |
1243 | 1247 | logprobs = torch.take_along_dim(logits.log_softmax(dim=-1), completion_ids.unsqueeze(-1), dim=2).squeeze(-1) |
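
Why `end_idx` must become `None` when there is no prompt: with `P` prompt tokens and `C` completion tokens, `output.logits` has `P + C` positions, and the completion tokens are scored by the slice `[P - 1 : P + C - 1]`, i.e. `[start_idx:-1]`. With `P == 0` that slice keeps only `C - 1` positions, so the `take_along_dim` call above would fail on a shape mismatch against the `C` completion ids. A small, self-contained shape check (illustrative sizes only):

```python
# Tiny shape check for the slicing introduced above (illustrative sizes only).
import torch

batch_size, vocab_size = 2, 11
for prompt_len, completion_len in [(3, 4), (0, 4)]:
    logits = torch.randn(batch_size, prompt_len + completion_len, vocab_size)
    completion_ids = torch.randint(vocab_size, (batch_size, completion_len))

    start_idx = prompt_len - 1 if prompt_len > 0 else 0
    end_idx = -1 if prompt_len > 0 else None
    sliced = logits[:, start_idx:end_idx]

    # One logit per completion token; with the old `[:, start_idx:-1]` slice and
    # prompt_len == 0 this would be completion_len - 1 and take_along_dim would fail.
    assert sliced.size(1) == completion_len
    torch.take_along_dim(sliced.log_softmax(dim=-1), completion_ids.unsqueeze(-1), dim=2).squeeze(-1)
```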
|