163 changes: 135 additions & 28 deletions trl/trainer/sft_trainer.py
@@ -14,6 +14,7 @@

import contextlib
import os
import re
from collections import defaultdict
from dataclasses import dataclass
from pathlib import Path
@@ -80,6 +81,60 @@ def get_dataset_column_names(dataset: Union[Dataset, IterableDataset]) -> list[str]:
return list(next(iter(dataset)).keys()) if dataset.column_names is None else dataset.column_names


def convert_to_structured_content(messages: list[dict[str, Any]], images: list, videos: list) -> list[dict[str, Any]]:
"""
Convert messages with <image> and <video> placeholder tags to structured content format.

This format is required by some VLM processors (like Qwen) that expect typed content objects rather than plain text
with placeholder tags.

Args:
messages: List of message dicts with role and content
images: List of image paths/objects corresponding to <image> tags
videos: List of video paths/objects corresponding to <video> tags

Returns:
List of messages with structured content format

Example:
Input: {"role": "user", "content": "<video>\nWhat's happening?"}
Output: {"role": "user", "content": [
    {"type": "video", "video": "/path/to/video.mp4"},
    {"type": "text", "text": "What's happening?"}
]}
"""
structured_messages = []
image_idx = 0
video_idx = 0

for msg in messages:
role = msg["role"]
content_str = msg["content"]

# Check if this message contains media placeholders
if "<video>" in content_str or "<image>" in content_str:
# Parse placeholders and create structured content
content = []
parts = re.split(r"(<video>|<image>)", content_str)

for part in parts:
if part == "<video>":
if video_idx < len(videos):
content.append({"type": "video", "video": videos[video_idx]})
video_idx += 1
elif part == "<image>":
if image_idx < len(images):
content.append({"type": "image", "image": images[image_idx]})
image_idx += 1
elif part.strip():
content.append({"type": "text", "text": part.strip()})

structured_messages.append({"role": role, "content": content})
else:
# No media placeholders - keep as plain text
structured_messages.append({"role": role, "content": content_str})

return structured_messages
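
As a rough usage sketch (not part of the patch), the helper above is expected to behave like this, reusing the paths from its docstring example:

messages = [{"role": "user", "content": "<video>\nWhat's happening?"}]
structured = convert_to_structured_content(messages, images=[], videos=["/path/to/video.mp4"])
# structured == [{"role": "user", "content": [
#     {"type": "video", "video": "/path/to/video.mp4"},
#     {"type": "text", "text": "What's happening?"},
# ]}]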


@dataclass
class DataCollatorForLanguageModeling(DataCollatorMixin):
"""
@@ -347,16 +402,35 @@ def torch_call(self, examples: list[dict[str, Any]]) -> dict[str, Any]:
raise KeyError(f"Unexpected input keys in examples: {list(examples[0].keys())}.")

def _collate_language_modeling(self, examples: list[dict[str, Any]]) -> dict[str, Any]:
images = [example["images"] for example in examples]
# Handle images
images = [example.get("images", []) for example in examples]
# Transformers requires at least one image in the batch, otherwise it throws an error
if all(img_list == [] for img_list in images):
images = None

# Handle videos
videos = [example.get("videos", []) for example in examples]
if all(vid_list == [] for vid_list in videos):
videos = None

if "messages" in examples[0]: # conversational case
messages_list = []
for example in examples:
prepare_multimodal_messages(example["messages"], len(example["images"]))
messages = [example["messages"] for example in examples]
texts = self.processor.apply_chat_template(messages)
num_images = len(example.get("images", []))
num_videos = len(example.get("videos", []))

# Use structured content format when we have any media (images or videos)
# This format works for processors like Qwen that expect typed content objects
if num_videos > 0 or num_images > 0:
structured_messages = convert_to_structured_content(
example["messages"], example.get("images", []), example.get("videos", [])
)
messages_list.append(structured_messages)
else:
# No media - keep original messages
messages_list.append(example["messages"])

texts = self.processor.apply_chat_template(messages_list)
elif self.dataset_text_field in examples[0]: # standard case
texts = [example[self.dataset_text_field] for example in examples]
else:
@@ -365,17 +439,28 @@ def _collate_language_modeling(self, examples: list[dict[str, Any]]) -> dict[str, Any]:
"data."
)

output = self.processor(
images=images,
text=texts,
padding=True,
padding_side="right",
pad_to_multiple_of=self.pad_to_multiple_of,
truncation=self.max_length is not None,
max_length=self.max_length,
return_tensors=self.return_tensors,
add_special_tokens=False, # to avoid adding the BOS twice, see https://huggingface.co/blog/qgallouedec/gotchas-in-tokenizer-behavior#7-chat-template-and-tokenization-dont-compose-due-to-special-tokens
)
# Process with both images and videos
processor_kwargs = {
"text": texts,
"padding": True,
"padding_side": "right",
"pad_to_multiple_of": self.pad_to_multiple_of,
"return_tensors": self.return_tensors,
"add_special_tokens": False, # to avoid adding the BOS twice, see https://huggingface.co/blog/qgallouedec/gotchas-in-tokenizer-behavior#7-chat-template-and-tokenization-dont-compose-due-to-special-tokens
}
# Pass truncation parameters to processor if max_length is set
# The processor will handle truncation appropriately for both images and videos
if self.max_length is not None:
processor_kwargs["truncation"] = True
processor_kwargs["max_length"] = self.max_length

if images is not None:
processor_kwargs["images"] = images
if videos is not None:
processor_kwargs["videos"] = videos

output = self.processor(**processor_kwargs)

labels = output["input_ids"].clone()
labels[output["attention_mask"] == 0] = -100
# We mask only padding tokens (-100) in the labels. Vision tokens are left unchanged because their handling in
@@ -390,26 +475,43 @@ def _collate_prompt_completion(self, examples: list[dict[str, Any]]) -> dict[str, Any]:
"Padding to a multiple of a value is not yet implemented for vision-language modeling and "
"prompt-completion data."
)
images = [example["images"] for example in examples]
# Handle images
images = [example.get("images", []) for example in examples]
# Transformers requires at least one image in the batch, otherwise it throws an error
if all(img_list == [] for img_list in images):
images = None

# Handle videos
videos = [example.get("videos", []) for example in examples]
if all(vid_list == [] for vid_list in videos):
videos = None

if is_conversational(examples[0]): # conversational case
for example in examples:
prepare_multimodal_messages(example["prompt"] + example["completion"], len(example["images"]))
num_images = len(example.get("images", []))
num_videos = len(example.get("videos", []))
# Only prepare multimodal messages for images; videos use native <video> tags
Review comment (Member), quoting "videos use native":

this doesn't make sense to me. For _collate_language_modeling we need to add {"type": "video"}, but not here?

if num_images > 0 and num_videos == 0:
prepare_multimodal_messages(example["prompt"] + example["completion"], num_images=num_images)
examples = [apply_chat_template(example, self.processor) for example in examples]

prompts = [example["prompt"] for example in examples]
completions = [example["completion"] for example in examples]

processed_prompts = self.processor(
images=images,
text=prompts,
padding=True,
padding_side="left",
return_tensors=self.return_tensors,
add_special_tokens=False, # to avoid adding the BOS twice, see https://huggingface.co/blog/qgallouedec/gotchas-in-tokenizer-behavior#7-chat-template-and-tokenization-dont-compose-due-to-special-tokens
)
# Build processor kwargs for prompts
prompt_kwargs = {
"text": prompts,
"padding": True,
"padding_side": "left",
"return_tensors": self.return_tensors,
"add_special_tokens": False, # to avoid adding the BOS twice, see https://huggingface.co/blog/qgallouedec/gotchas-in-tokenizer-behavior#7-chat-template-and-tokenization-dont-compose-due-to-special-tokens
}
if images is not None:
prompt_kwargs["images"] = images
if videos is not None:
prompt_kwargs["videos"] = videos

processed_prompts = self.processor(**prompt_kwargs)
processed_completions = self.processor(
text=completions,
padding=True,
@@ -738,10 +840,15 @@ def __init__(
else:
self.completion_only_loss = args.completion_only_loss

self._is_vision_dataset = "image" in dataset_sample or "images" in dataset_sample
self._is_vision_dataset = (
"image" in dataset_sample
or "images" in dataset_sample
or "video" in dataset_sample
or "videos" in dataset_sample
)
if self._is_vision_dataset and not self._is_vlm:
raise ValueError(
"The dataset appears to be vision-related (contains 'image' or 'images' keys), but the provided "
"The dataset appears to be vision-related (contains 'image', 'images', 'video', or 'videos' keys), but the provided "
"model does not seem to be a vision-language model. Please check your model and dataset."
)
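
For illustration (not part of the patch), a minimal conversational sample that the extended check above now treats as a vision dataset; the assistant reply is made up, and the field names follow the <video> placeholder convention handled by convert_to_structured_content:

dataset_sample = {
    "messages": [
        {"role": "user", "content": "<video>\nWhat's happening?"},
        {"role": "assistant", "content": "A short description of the clip."},
    ],
    "videos": ["/path/to/video.mp4"],
}
# "videos" is present, so _is_vision_dataset is True and a non-VLM model triggers the ValueError above.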

@@ -1073,7 +1180,7 @@ def _set_signature_columns_if_needed(self):
# dataset. So we need to override the default signature columns to include "completion_mask" as well.
if self._signature_columns is None:
if self._is_vision_dataset:
self._signature_columns = ["messages", "prompt", "completion", "images"]
self._signature_columns = ["messages", "prompt", "completion", "images", "videos"]
else:
self._signature_columns = ["input_ids", "labels", "seq_lengths", "completion_mask", "assistant_masks"]
