153 changes: 111 additions & 42 deletions trl/trainer/sft_trainer.py
@@ -346,36 +346,79 @@ def torch_call(self, examples: list[dict[str, Any]]) -> dict[str, Any]:
else:
raise KeyError(f"Unexpected input keys in examples: {list(examples[0].keys())}.")

def _has_structured_content(self, messages: list[dict]) -> tuple[bool, bool]:
"""
Check if messages contain structured content with images or videos.

Returns:
tuple[bool, bool]: (has_image_content, has_video_content)
"""
has_image_content = False
has_video_content = False

if messages and isinstance(messages, list):
for msg in messages:
if isinstance(msg.get("content"), list):
for item in msg["content"]:
if isinstance(item, dict):
if item.get("type") == "image":
has_image_content = True
elif item.get("type") == "video":
has_video_content = True
if has_image_content and has_video_content:
break

return has_image_content, has_video_content

def _collate_language_modeling(self, examples: list[dict[str, Any]]) -> dict[str, Any]:
images = [example["images"] for example in examples]
# Transformers requires at least one image in the batch, otherwise it throws an error
if all(img_list == [] for img_list in images):
images = None

if "messages" in examples[0]: # conversational case
for example in examples:
prepare_multimodal_messages(example["messages"], len(example["images"]))
messages = [example["messages"] for example in examples]
texts = self.processor.apply_chat_template(messages)
elif self.dataset_text_field in examples[0]: # standard case
# Extract images and videos from examples
images = [example.get("images", []) for example in examples]
videos = [example.get("videos", []) for example in examples]
images = None if all(img == [] for img in images) else images
videos = None if all(vid == [] for vid in videos) else videos

# Apply chat template for conversational data
if "messages" in examples[0]:
messages_list = [example["messages"] for example in examples]
# Check if messages use structured content format ({"type": "image"} or {"type": "video"})
has_image_content, has_video_content = self._has_structured_content(messages_list[0])

# For structured content, pass images/videos to apply_chat_template for extraction
template_kwargs = {}
if has_image_content and images:
template_kwargs["images"] = images
if has_video_content and videos:
template_kwargs["videos"] = videos
texts = self.processor.apply_chat_template(messages_list, **template_kwargs)
elif self.dataset_text_field in examples[0]:
texts = [example[self.dataset_text_field] for example in examples]
has_image_content = has_video_content = False
else:
raise KeyError(
"The input examples must contain either 'messages' for conversational data or 'text' for standard "
"data."
"The input examples must contain either 'messages' for conversational data or 'text' for standard data."
)

output = self.processor(
images=images,
text=texts,
padding=True,
padding_side="right",
pad_to_multiple_of=self.pad_to_multiple_of,
truncation=self.max_length is not None,
max_length=self.max_length,
return_tensors=self.return_tensors,
add_special_tokens=False, # to avoid adding the BOS twice, see https://huggingface.co/blog/qgallouedec/gotchas-in-tokenizer-behavior#7-chat-template-and-tokenization-dont-compose-due-to-special-tokens
)
# Build processor kwargs
processor_kwargs = {
"text": texts,
"padding": True,
"padding_side": "right",
"pad_to_multiple_of": self.pad_to_multiple_of,
"return_tensors": self.return_tensors,
"add_special_tokens": False,
}
if self.max_length is not None:
processor_kwargs["truncation"] = True
processor_kwargs["max_length"] = self.max_length

# Add images/videos to processor only if not already in structured content
if images and not has_image_content:
processor_kwargs["images"] = images
if videos and not has_video_content:
processor_kwargs["videos"] = videos

output = self.processor(**processor_kwargs)

labels = output["input_ids"].clone()
labels[output["attention_mask"] == 0] = -100
# We mask only padding tokens (-100) in the labels. Vision tokens are left unchanged because their handling in
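
For context, a rough sketch of the structured content format that the new _has_structured_content helper looks for; the example data below is assumed for illustration and is not part of this diff. Each message's "content" is a list of typed items rather than a plain string:

# Hypothetical conversational example (assumed layout, not from this PR):
example = {
    "messages": [
        {
            "role": "user",
            "content": [
                {"type": "image"},
                {"type": "text", "text": "What is shown in this picture?"},
            ],
        },
        {"role": "assistant", "content": [{"type": "text", "text": "A cat sitting on a sofa."}]},
    ],
    "images": [pil_image],  # pil_image is a hypothetical PIL.Image loaded elsewhere
}
# _has_structured_content(example["messages"]) would return (True, False), so the collator
# passes the images to apply_chat_template and leaves them out of the processor kwargs.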
@@ -390,26 +433,47 @@ def _collate_prompt_completion(self, examples: list[dict[str, Any]]) -> dict[str, Any]:
"Padding to a multiple of a value is not yet implemented for vision-language modeling and "
"prompt-completion data yet."
)
images = [example["images"] for example in examples]
# Transformers requires at least one image in the batch, otherwise it throws an error
if all(img_list == [] for img_list in images):
images = None
if is_conversational(examples[0]): # conversational case
for example in examples:
prepare_multimodal_messages(example["prompt"] + example["completion"], len(example["images"]))
# Extract images and videos from examples
images = [example.get("images", []) for example in examples]
videos = [example.get("videos", []) for example in examples]
images = None if all(img == [] for img in images) else images
videos = None if all(vid == [] for vid in videos) else videos

# Apply chat template for conversational data
if is_conversational(examples[0]):
# Check if messages use structured content format
first_prompt_completion = examples[0]["prompt"] + examples[0]["completion"]
has_image_content, has_video_content = self._has_structured_content(first_prompt_completion)

# For non-structured content, add image placeholders (videos require structured content)
if not (has_image_content or has_video_content):
for example in examples:
num_images = len(example.get("images", []))
if num_images > 0 and not example.get("videos"):
prepare_multimodal_messages(example["prompt"] + example["completion"], num_images=num_images)

examples = [apply_chat_template(example, self.processor) for example in examples]
else:
has_image_content = has_video_content = False

prompts = [example["prompt"] for example in examples]
completions = [example["completion"] for example in examples]

processed_prompts = self.processor(
images=images,
text=prompts,
padding=True,
padding_side="left",
return_tensors=self.return_tensors,
add_special_tokens=False, # to avoid adding the BOS twice, see https://huggingface.co/blog/qgallouedec/gotchas-in-tokenizer-behavior#7-chat-template-and-tokenization-dont-compose-due-to-special-tokens
)
# Build processor kwargs for prompts
prompt_kwargs = {
"text": prompts,
"padding": True,
"padding_side": "left",
"return_tensors": self.return_tensors,
"add_special_tokens": False,
}
# Add images/videos to processor only if not already in structured content
if images and not has_image_content:
prompt_kwargs["images"] = images
if videos and not has_video_content:
prompt_kwargs["videos"] = videos

processed_prompts = self.processor(**prompt_kwargs)
processed_completions = self.processor(
text=completions,
padding=True,
@@ -738,10 +802,15 @@ def __init__(
else:
self.completion_only_loss = args.completion_only_loss

self._is_vision_dataset = "image" in dataset_sample or "images" in dataset_sample
self._is_vision_dataset = (
"image" in dataset_sample
or "images" in dataset_sample
or "video" in dataset_sample
or "videos" in dataset_sample
)
if self._is_vision_dataset and not self._is_vlm:
raise ValueError(
"The dataset appears to be vision-related (contains 'image' or 'images' keys), but the provided "
"The dataset appears to be vision-related (contains 'image', 'images', 'video', or 'videos' keys), but the provided "
"model does not seem to be a vision-language model. Please check your model and dataset."
)
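
As a rough illustration of the widened vision check above (assumed sample, not taken from this PR), a dataset row carrying only a "videos" key would now be flagged as vision-related and would trigger this error when the model is not a VLM:

# Hypothetical dataset sample (assumed layout):
dataset_sample = {
    "messages": [
        {"role": "user", "content": [{"type": "video"}, {"type": "text", "text": "Describe the clip."}]},
        {"role": "assistant", "content": "A person waters a plant."},
    ],
    "videos": [video_frames],  # video_frames is hypothetical, e.g. a list of decoded frames
}
# "videos" in dataset_sample is True, so _is_vision_dataset becomes True; with a text-only
# model (_is_vlm False) the ValueError above would be raised.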

@@ -1073,7 +1142,7 @@ def _set_signature_columns_if_needed(self):
# dataset. So we need to override the default signature columns to include "completion_mask" as well.
if self._signature_columns is None:
if self._is_vision_dataset:
self._signature_columns = ["messages", "prompt", "completion", "images"]
self._signature_columns = ["messages", "prompt", "completion", "images", "videos"]
else:
self._signature_columns = ["input_ids", "labels", "seq_lengths", "completion_mask", "assistant_masks"]
