diff --git a/trl/trainer/sft_trainer.py b/trl/trainer/sft_trainer.py
index 9014470b55..bd75be4bdc 100644
--- a/trl/trainer/sft_trainer.py
+++ b/trl/trainer/sft_trainer.py
@@ -346,36 +346,79 @@ def torch_call(self, examples: list[dict[str, Any]]) -> dict[str, Any]:
         else:
             raise KeyError(f"Unexpected input keys in examples: {list(examples[0].keys())}.")

+    def _has_structured_content(self, messages: list[dict]) -> tuple[bool, bool]:
+        """
+        Check if messages contain structured content with images or videos.
+
+        Returns:
+            tuple[bool, bool]: (has_image_content, has_video_content)
+        """
+        has_image_content = False
+        has_video_content = False
+
+        if messages and isinstance(messages, list):
+            for msg in messages:
+                if isinstance(msg.get("content"), list):
+                    for item in msg["content"]:
+                        if isinstance(item, dict):
+                            if item.get("type") == "image":
+                                has_image_content = True
+                            elif item.get("type") == "video":
+                                has_video_content = True
+                if has_image_content and has_video_content:
+                    break
+
+        return has_image_content, has_video_content
+
     def _collate_language_modeling(self, examples: list[dict[str, Any]]) -> dict[str, Any]:
-        images = [example["images"] for example in examples]
-        # Transformers requires at least one image in the batch, otherwise it throws an error
-        if all(img_list == [] for img_list in images):
-            images = None
-
-        if "messages" in examples[0]:  # conversational case
-            for example in examples:
-                prepare_multimodal_messages(example["messages"], len(example["images"]))
-            messages = [example["messages"] for example in examples]
-            texts = self.processor.apply_chat_template(messages)
-        elif self.dataset_text_field in examples[0]:  # standard case
+        # Extract images and videos from examples
+        images = [example.get("images", []) for example in examples]
+        videos = [example.get("videos", []) for example in examples]
+        images = None if all(img == [] for img in images) else images
+        videos = None if all(vid == [] for vid in videos) else videos
+
+        # Apply chat template for conversational data
+        if "messages" in examples[0]:
+            messages_list = [example["messages"] for example in examples]
+            # Check if messages use structured content format ({"type": "image"} or {"type": "video"})
+            has_image_content, has_video_content = self._has_structured_content(messages_list[0])
+
+            # For structured content, pass images/videos to apply_chat_template for extraction
+            template_kwargs = {}
+            if has_image_content and images:
+                template_kwargs["images"] = images
+            if has_video_content and videos:
+                template_kwargs["videos"] = videos
+            texts = self.processor.apply_chat_template(messages_list, **template_kwargs)
+        elif self.dataset_text_field in examples[0]:
             texts = [example[self.dataset_text_field] for example in examples]
+            has_image_content = has_video_content = False
         else:
             raise KeyError(
-                "The input examples must contain either 'messages' for conversational data or 'text' for standard "
-                "data."
+                "The input examples must contain either 'messages' for conversational data or 'text' for standard data."
             )

-        output = self.processor(
-            images=images,
-            text=texts,
-            padding=True,
-            padding_side="right",
-            pad_to_multiple_of=self.pad_to_multiple_of,
-            truncation=self.max_length is not None,
-            max_length=self.max_length,
-            return_tensors=self.return_tensors,
-            add_special_tokens=False,  # to avoid adding the BOS, twice see https://huggingface.co/blog/qgallouedec/gotchas-in-tokenizer-behavior#7-chat-template-and-tokenization-dont-compose-due-to-special-tokens
-        )
+        # Build processor kwargs
+        processor_kwargs = {
+            "text": texts,
+            "padding": True,
+            "padding_side": "right",
+            "pad_to_multiple_of": self.pad_to_multiple_of,
+            "return_tensors": self.return_tensors,
+            "add_special_tokens": False,  # to avoid adding the BOS twice, see https://huggingface.co/blog/qgallouedec/gotchas-in-tokenizer-behavior#7-chat-template-and-tokenization-dont-compose-due-to-special-tokens
+        }
+        if self.max_length is not None:
+            processor_kwargs["truncation"] = True
+            processor_kwargs["max_length"] = self.max_length
+
+        # Add images/videos to processor only if not already in structured content
+        if images and not has_image_content:
+            processor_kwargs["images"] = images
+        if videos and not has_video_content:
+            processor_kwargs["videos"] = videos
+
+        output = self.processor(**processor_kwargs)
+
         labels = output["input_ids"].clone()
         labels[output["attention_mask"] == 0] = -100
         # We mask only padding tokens (-100) in the labels. Vision tokens are left unchanged because their handling in
@@ -390,26 +433,47 @@ def _collate_prompt_completion(self, examples: list[dict[str, Any]]) -> dict[str
                 "Padding to a multiple of a value is not yet implemented for vision-language modeling and "
                 "prompt-completion data yet."
             )
-        images = [example["images"] for example in examples]
-        # Transformers requires at least one image in the batch, otherwise it throws an error
-        if all(img_list == [] for img_list in images):
-            images = None
-        if is_conversational(examples[0]):  # conversational case
-            for example in examples:
-                prepare_multimodal_messages(example["prompt"] + example["completion"], len(example["images"]))
+        # Extract images and videos from examples
+        images = [example.get("images", []) for example in examples]
+        videos = [example.get("videos", []) for example in examples]
+        images = None if all(img == [] for img in images) else images
+        videos = None if all(vid == [] for vid in videos) else videos
+
+        # Apply chat template for conversational data
+        if is_conversational(examples[0]):
+            # Check if messages use structured content format
+            first_prompt_completion = examples[0]["prompt"] + examples[0]["completion"]
+            has_image_content, has_video_content = self._has_structured_content(first_prompt_completion)
+
+            # For non-structured content, add image placeholders (videos require structured content)
+            if not (has_image_content or has_video_content):
+                for example in examples:
+                    num_images = len(example.get("images", []))
+                    if num_images > 0 and not example.get("videos"):
+                        prepare_multimodal_messages(example["prompt"] + example["completion"], num_images=num_images)
+
             examples = [apply_chat_template(example, self.processor) for example in examples]
+        else:
+            has_image_content = has_video_content = False

         prompts = [example["prompt"] for example in examples]
         completions = [example["completion"] for example in examples]

-        processed_prompts = self.processor(
-            images=images,
-            text=prompts,
-            padding=True,
-            padding_side="left",
-            return_tensors=self.return_tensors,
-            add_special_tokens=False,  # to avoid adding the BOS, twice see https://huggingface.co/blog/qgallouedec/gotchas-in-tokenizer-behavior#7-chat-template-and-tokenization-dont-compose-due-to-special-tokens
-        )
+        # Build processor kwargs for prompts
+        prompt_kwargs = {
+            "text": prompts,
+            "padding": True,
+            "padding_side": "left",
+            "return_tensors": self.return_tensors,
+            "add_special_tokens": False,  # to avoid adding the BOS twice, see https://huggingface.co/blog/qgallouedec/gotchas-in-tokenizer-behavior#7-chat-template-and-tokenization-dont-compose-due-to-special-tokens
+        }
+        # Add images/videos to processor only if not already in structured content
+        if images and not has_image_content:
+            prompt_kwargs["images"] = images
+        if videos and not has_video_content:
+            prompt_kwargs["videos"] = videos
+
+        processed_prompts = self.processor(**prompt_kwargs)
         processed_completions = self.processor(
             text=completions,
             padding=True,
@@ -738,10 +802,15 @@ def __init__(
         else:
             self.completion_only_loss = args.completion_only_loss

-        self._is_vision_dataset = "image" in dataset_sample or "images" in dataset_sample
+        self._is_vision_dataset = (
+            "image" in dataset_sample
+            or "images" in dataset_sample
+            or "video" in dataset_sample
+            or "videos" in dataset_sample
+        )
         if self._is_vision_dataset and not self._is_vlm:
             raise ValueError(
-                "The dataset appears to be vision-related (contains 'image' or 'images' keys), but the provided "
+                "The dataset appears to be vision-related (contains 'image', 'images', 'video', or 'videos' keys), but the provided "
                 "model does not seem to be a vision-language model. Please check your model and dataset."
             )

@@ -1073,7 +1142,7 @@ def _set_signature_columns_if_needed(self):
         # dataset. So we need to override the default signature columns to include "completion_mask" as well.
         if self._signature_columns is None:
             if self._is_vision_dataset:
-                self._signature_columns = ["messages", "prompt", "completion", "images"]
+                self._signature_columns = ["messages", "prompt", "completion", "images", "videos"]
             else:
                 self._signature_columns = ["input_ids", "labels", "seq_lengths", "completion_mask", "assistant_masks"]
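
Illustration (not part of the patch): a minimal, hypothetical sketch of the two conversational formats the collators above now distinguish. The placeholder strings stand in for real PIL images and video-frame lists; neither example is taken from the TRL test suite.

# Plain string content: `_has_structured_content` returns (False, False), so
# `images` goes straight to the processor call (and, in the prompt-completion
# path, placeholders are injected via `prepare_multimodal_messages`).
plain_example = {
    "messages": [{"role": "user", "content": "What is in this image?"}],
    "images": ["<PIL.Image placeholder>"],
}

# Structured content: `_has_structured_content` returns (False, True) here, so
# the collator forwards `videos` to `apply_chat_template` instead of passing
# them to the processor call.
structured_example = {
    "messages": [
        {
            "role": "user",
            "content": [
                {"type": "video"},
                {"type": "text", "text": "Describe what happens in this clip."},
            ],
        }
    ],
    "videos": ["<list-of-frames placeholder>"],
}

Note that mixing plain-string messages with a "videos" key is not supported by this split: `prepare_multimodal_messages` only injects image placeholders, which is why the prompt-completion path guards that call with `not example.get("videos")`.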