
Commit c2337cf

mapmeld and clefourrier authored
Allow AdapterModels to have custom tokens (#306)
--------- Co-authored-by: Clémentine Fourrier <[email protected]>
1 parent 3653adf commit c2337cf

File tree

2 files changed: +14 −2 lines changed


src/lighteval/main_accelerate.py

Lines changed: 2 additions & 2 deletions
@@ -169,11 +169,11 @@ def accelerate( # noqa C901
         # Keeping only non null params
         args_dict = {k: v for k, v in args_dict.items() if v is not None}

-        if config["merged_weights"]["delta_weights"]:
+        if config["merged_weights"].get("delta_weights", False):
             if config["merged_weights"]["base_model"] is None:
                 raise ValueError("You need to specify a base model when using delta weights")
             model_config = DeltaModelConfig(**args_dict)
-        elif config["merged_weights"]["adapter_weights"]:
+        elif config["merged_weights"].get("adapter_weights", False):
             if config["merged_weights"]["base_model"] is None:
                 raise ValueError("You need to specify a base model when using adapter weights")
             model_config = AdapterModelConfig(**args_dict)
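The functional change here is small but worth noting: indexing `config["merged_weights"]["delta_weights"]` directly raises a `KeyError` when a user's YAML config omits the key, while `.get("delta_weights", False)` falls back to `False` and simply skips the branch. A minimal sketch of the difference, using a hypothetical config dict in place of a parsed YAML file:

```python
# Hypothetical stand-in for a parsed YAML config that omits the
# delta_weights / adapter_weights keys entirely.
config = {"merged_weights": {"base_model": None}}

# Old behavior: direct indexing fails on configs missing the key.
try:
    if config["merged_weights"]["delta_weights"]:
        pass
except KeyError as err:
    print(f"KeyError: {err}")  # KeyError: 'delta_weights'

# New behavior: .get() defaults to False, so the branch is skipped
# instead of crashing.
if config["merged_weights"].get("delta_weights", False):
    pass  # not reached for this config
```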

src/lighteval/models/transformers/adapter_model.py

Lines changed: 12 additions & 0 deletions
@@ -84,6 +84,18 @@ def _create_auto_model(self, config: AdapterModelConfig, env_config: EnvConfig)
         base = AutoModelForCausalLM.from_pretrained(
             config.base_model, torch_dtype=torch.float16, low_cpu_mem_usage=True, token=env_config.token
         )
+        # resize model for adapters with added tokens
+        token_diff = len(self._tokenizer) - base.config.vocab_size
+        if token_diff != 0:
+            if token_diff > 0:
+                logger.info(
+                    f"You're using the adapter model's tokenizer, which has more tokens than the base model. Adding {token_diff} token(s)."
+                )
+            else:
+                logger.info(
+                    f"You're using the adapter model's tokenizer, which has fewer tokens than the base model. Removing {abs(token_diff)} token(s)."
+                )
+            base.resize_token_embeddings(len(self._tokenizer))
         # Should pass revision
         model = PeftModel.from_pretrained(base, adapter_weights)
         model = model.merge_and_unload()
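The added block compares the adapter tokenizer's length against the base model's vocabulary size and resizes the embedding matrix before the PEFT weights are merged, so adapters trained with added custom tokens no longer hit a shape mismatch on merge. A minimal sketch of the same resize step in isolation, using `gpt2` as a hypothetical base model:

```python
from transformers import AutoModelForCausalLM, AutoTokenizer

# gpt2 is a stand-in here for an arbitrary base model; the added tokens
# mimic an adapter whose tokenizer grew during fine-tuning.
base = AutoModelForCausalLM.from_pretrained("gpt2")
tokenizer = AutoTokenizer.from_pretrained("gpt2")
tokenizer.add_tokens(["<custom_token_1>", "<custom_token_2>"])

token_diff = len(tokenizer) - base.config.vocab_size  # +2 in this sketch
if token_diff != 0:
    # Resizes the input embeddings (and the tied LM head) so the model's
    # vocabulary lines up with the tokenizer's.
    base.resize_token_embeddings(len(tokenizer))

assert base.config.vocab_size == len(tokenizer)
```

After the resize, loading the adapter with `PeftModel.from_pretrained(base, ...)` and merging proceeds against embedding matrices of the expected size.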
