Skip to content

Commit 7ad9ce8

Browse files
Remove tokenizer creation from sft example script (#4197)
1 parent 0c2dc14 commit 7ad9ce8

File tree

1 file changed: +2 additions, −8 deletions

trl/scripts/sft.py

Lines changed: 2 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -68,7 +68,7 @@
 from accelerate import logging
 from datasets import load_dataset

-from transformers import AutoConfig, AutoModelForCausalLM, AutoTokenizer
+from transformers import AutoConfig, AutoModelForCausalLM
 from transformers.models.auto.modeling_auto import MODEL_FOR_IMAGE_TEXT_TO_TEXT_MAPPING_NAMES

 from trl import (
@@ -93,7 +93,7 @@
 def main(script_args, training_args, model_args, dataset_args):
     ################
-    # Model init kwargs & Tokenizer
+    # Model init kwargs
     ################
     model_kwargs = dict(
         revision=model_args.model_revision,
@@ -118,11 +118,6 @@ def main(script_args, training_args, model_args, dataset_args):
     else:
         model = AutoModelForCausalLM.from_pretrained(model_args.model_name_or_path, **model_kwargs)

-    # Create tokenizer
-    tokenizer = AutoTokenizer.from_pretrained(
-        model_args.model_name_or_path, trust_remote_code=model_args.trust_remote_code, use_fast=True
-    )
-
     # Load the dataset
     if dataset_args.datasets and script_args.dataset_name:
         logger.warning(
@@ -145,7 +140,6 @@ def main(script_args, training_args, model_args, dataset_args):
         args=training_args,
         train_dataset=dataset[script_args.dataset_train_split],
         eval_dataset=dataset[script_args.dataset_test_split] if training_args.eval_strategy != "no" else None,
-        processing_class=tokenizer,
         peft_config=get_peft_config(model_args),
     )
0 commit comments

Comments
 (0)