
Commit 96429e7

Add support for GGUF Phi-3 (#31844)
* Update docs for GGUF supported models
* Add tensor mappings and define class GGUFPhi3Converter
* Fix tokenizer
* Working version
* Attempt to fix some CI failures
* Run ruff format
* Add vocab, merges, decoder methods like LlamaConverter
* Resolve conflicts since Qwen2Moe was added to gguf
  - I missed one place when resolving the conflict
  - I also made a mistake with test_ggml.py; it has now been fixed to reflect the master version
1 parent 8e8e7d8 commit 96429e7

File tree: 5 files changed (+122 / -1 lines)

docs/source/en/gguf.md

Lines changed: 1 addition & 0 deletions
@@ -79,6 +79,7 @@ For now the supported model architectures are the architectures that have been v
 - Mistral
 - Qwen2
 - Qwen2Moe
+- Phi3
 
 ## Example usage
 

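For context, the "Example usage" section that follows this list can now be exercised with Phi-3 as well. A minimal sketch, using the same repo id and GGUF file name as the integration test added later in this commit:

import torch
from transformers import AutoModelForCausalLM, AutoTokenizer

model_id = "microsoft/Phi-3-mini-4k-instruct-gguf"
gguf_file = "Phi-3-mini-4k-instruct-q4.gguf"

tokenizer = AutoTokenizer.from_pretrained(model_id, gguf_file=gguf_file)
model = AutoModelForCausalLM.from_pretrained(
    model_id, gguf_file=gguf_file, device_map="auto", torch_dtype=torch.float16
)

inputs = tokenizer("Hello", return_tensors="pt").to(model.device)
out = model.generate(**inputs, max_new_tokens=10)
print(tokenizer.decode(out[0], skip_special_tokens=True))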
src/transformers/convert_slow_tokenizer.py

Lines changed: 1 addition & 0 deletions
@@ -1575,6 +1575,7 @@ def converted(self) -> Tokenizer:
     "LlamaTokenizer": LlamaConverter,
     "CodeLlamaTokenizer": LlamaConverter,
     "GemmaTokenizer": GemmaConvert,
+    "Phi3Tokenizer": LlamaConverter,
 }
 
 

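With this entry, a slow Phi-3 tokenizer is converted to a fast one by the existing Llama converter. A rough sketch of how the registry is consulted (assuming `slow_tok` is a slow tokenizer instance whose class name is "Phi3Tokenizer"):

from transformers.convert_slow_tokenizer import SLOW_TO_FAST_CONVERTERS

# Look up the converter by class name; for "Phi3Tokenizer" this resolves to LlamaConverter.
converter_cls = SLOW_TO_FAST_CONVERTERS[slow_tok.__class__.__name__]
fast_backend = converter_cls(slow_tok).converted()  # tokenizers.Tokenizer instance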
src/transformers/convert_slow_tokenizers_checkpoints_to_fast.py

Lines changed: 5 additions & 1 deletion
@@ -28,7 +28,11 @@
 logger = logging.get_logger(__name__)
 
 
-TOKENIZER_CLASSES = {name: getattr(transformers, name + "Fast") for name in SLOW_TO_FAST_CONVERTERS}
+TOKENIZER_CLASSES = {
+    # Phi3 uses Llama tokenizer
+    name: getattr(transformers, "LlamaTokenizerFast" if name == "Phi3Tokenizer" else name + "Fast")
+    for name in SLOW_TO_FAST_CONVERTERS
+}
 
 
 def convert_slow_checkpoint_to_fast(tokenizer_name, checkpoint_name, dump_path, force_download):

src/transformers/integrations/ggml.py

Lines changed: 101 additions & 0 deletions
@@ -94,6 +94,19 @@
         "output.weight": "lm_head.weight",
         "output_norm": "model.norm",
     },
+    "phi3": {
+        "token_embd": "model.embed_tokens",
+        "blk": "model.layers",
+        "ffn_up": "mlp.gate_up_proj",
+        "ffn_down": "mlp.down_proj",
+        "ffn_gate": "mlp.gate_up_proj",
+        "ffn_norm": "post_attention_layernorm",
+        "attn_norm": "input_layernorm",
+        "attn_qkv": "self_attn.qkv_proj",
+        "attn_output": "self_attn.o_proj",
+        "output.weight": "lm_head.weight",
+        "output_norm": "model.norm",
+    },
 }
 
 
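These entries drive the renaming of GGUF tensor names into transformers parameter names when the checkpoint is loaded; note that both ffn_up and ffn_gate point at mlp.gate_up_proj because Phi-3 fuses the gate and up projections. A simplified, hypothetical illustration of the renaming (not the loader's actual code; dotted keys such as "output.weight" are looked up whole in the real mapping):

def rename_phi3_tensor(name, mapping):
    # Replace each dot-separated component that has an entry in the mapping.
    return ".".join(mapping.get(part, part) for part in name.split("."))

phi3_map = {
    "blk": "model.layers",
    "attn_qkv": "self_attn.qkv_proj",
    "attn_norm": "input_layernorm",
    "ffn_down": "mlp.down_proj",
    "token_embd": "model.embed_tokens",
}

print(rename_phi3_tensor("blk.0.attn_qkv.weight", phi3_map))
# model.layers.0.self_attn.qkv_proj.weight
print(rename_phi3_tensor("token_embd.weight", phi3_map))
# model.embed_tokens.weight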
@@ -156,6 +169,18 @@
         "ggml.unknown_token_id": "unk_token_id",
         "ggml.padding_token_id": "pad_token_id",
     },
+    "phi3": {
+        "context_length": "max_position_embeddings",
+        "block_count": "num_hidden_layers",
+        "feed_forward_length": "intermediate_size",
+        "embedding_length": "hidden_size",
+        "rope.dimension_count": None,
+        "rope.freq_base": "rope_theta",
+        "attention.head_count": "num_attention_heads",
+        "attention.head_count_kv": "num_key_value_heads",
+        "attention.layer_norm_rms_epsilon": "rms_norm_eps",
+        "vocab_size": "vocab_size",
+    },
 }
 
 GGUF_TOKENIZER_MAPPING = {
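The config-side mapping translates GGUF metadata keys (the phi3.* fields in the file header) into Hugging Face config attributes; a key mapped to None, such as rope.dimension_count, is recognized but not propagated to the config. A short sketch with illustrative values (not read from a real file):

phi3_config_map = {
    "context_length": "max_position_embeddings",
    "block_count": "num_hidden_layers",
    "embedding_length": "hidden_size",
    "rope.dimension_count": None,  # present in the file, intentionally not forwarded
}

gguf_metadata = {"context_length": 4096, "block_count": 32, "embedding_length": 3072, "rope.dimension_count": 96}

config_kwargs = {}
for gguf_key, value in gguf_metadata.items():
    hf_key = phi3_config_map.get(gguf_key)
    if hf_key is not None:
        config_kwargs[hf_key] = value

print(config_kwargs)
# {'max_position_embeddings': 4096, 'num_hidden_layers': 32, 'hidden_size': 3072}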
@@ -390,10 +415,86 @@ def converted(self) -> Tokenizer:
         return tokenizer
 
 
+class GGUFPhi3Converter(LlamaConverter):
+    def __init__(self, tokenizer_dict):
+        self.proto = GGUFTokenizerSkeleton(tokenizer_dict)
+        self.original_tokenizer = self.proto
+        self.additional_kwargs = {}
+
+    def vocab(self, proto):
+        return list(zip(proto.tokens, proto.scores))
+
+    def merges(self, proto):
+        return proto.merges
+
+    def tokenizer(self, proto):
+        vocab_scores = self.vocab(self.proto)
+        merges = self.merges(self.proto)
+        bpe_vocab = {word: i for i, (word, _score) in enumerate(vocab_scores)}
+
+        tokenizer = Tokenizer(BPE(bpe_vocab, merges))
+        # add the special tokens from phi3 tokenizer config
+        tokenizer.add_special_tokens(
+            [
+                AddedToken("</s>", rstrip=True, lstrip=False, normalized=False, special=True),
+                AddedToken("<|endoftext|>", normalized=False, special=True),
+                AddedToken("<|assistant|>", rstrip=True, normalized=False, special=True),
+                AddedToken("<|placeholder1|>", rstrip=True, normalized=False, special=True),
+                AddedToken("<|placeholder2|>", rstrip=True, normalized=False, special=True),
+                AddedToken("<|placeholder3|>", rstrip=True, normalized=False, special=True),
+                AddedToken("<|placeholder4|>", rstrip=True, normalized=False, special=True),
+                AddedToken("<|system|>", rstrip=True, normalized=False, special=True),
+                AddedToken("<|end|>", rstrip=True, normalized=False, special=True),
+                AddedToken("<|placeholder5|>", rstrip=True, normalized=False, special=True),
+                AddedToken("<|placeholder6|>", rstrip=True, normalized=False, special=True),
+                AddedToken("<|user|>", rstrip=True, normalized=False, special=True),
+            ]
+        )
+
+        self.additional_kwargs["unk_token"] = (
+            proto.tokens[proto.unk_token_id] if proto.unk_token_id is not None else None
+        )
+        self.additional_kwargs["eos_token"] = (
+            proto.tokens[proto.eos_token_id] if proto.eos_token_id is not None else None
+        )
+        self.additional_kwargs["bos_token"] = (
+            proto.tokens[proto.bos_token_id] if proto.bos_token_id is not None else None
+        )
+        self.additional_kwargs["pad_token"] = (
+            proto.tokens[proto.pad_token_id] if proto.pad_token_id is not None else None
+        )
+
+        return tokenizer
+
+    def decoder(self, replacement, add_prefix_space):
+        sequence = [
+            decoders.ByteFallback(),
+            decoders.Fuse(),
+            decoders.Replace(replacement, " "),
+        ]
+
+        if add_prefix_space:
+            sequence += [decoders.Strip(content=" ", left=1)]
+        return decoders.Sequence(sequence)
+
+    def converted(self) -> Tokenizer:
+        tokenizer = self.tokenizer(self.proto)
+
+        replacement = "▁"
+        add_prefix_space = True
+        if hasattr(self.original_tokenizer, "add_prefix_space"):
+            add_prefix_space = self.original_tokenizer.add_prefix_space
+
+        tokenizer.decoder = self.decoder(replacement, add_prefix_space)
+
+        return tokenizer
+
+
 GGUF_TO_FAST_CONVERTERS = {
     "llama": GGUFLlamaConverter,
     "qwen2": GGUFQwen2Converter,
     "qwen2_moe": GGUFQwen2Converter,
+    "phi3": GGUFPhi3Converter,
 }
 
 
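Downstream, the GGUF tokenizer loading path selects a converter from this registry using the architecture string stored in the file. A minimal sketch of how the new entry would be exercised (assuming `tokenizer_dict` holds the tokenizer section parsed from the GGUF file: tokens, scores, merges and the *_token_id fields the class above reads):

from transformers.integrations.ggml import GGUF_TO_FAST_CONVERTERS

converter = GGUF_TO_FAST_CONVERTERS["phi3"](tokenizer_dict)  # GGUFPhi3Converter
fast_backend = converter.converted()          # tokenizers.Tokenizer with BPE model and ByteFallback decoder
special_tokens = converter.additional_kwargs  # unk/eos/bos/pad tokens recovered from the GGUF ids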
tests/quantization/ggml/test_ggml.py

Lines changed: 14 additions & 0 deletions
@@ -41,6 +41,7 @@ class GgufIntegrationTests(unittest.TestCase):
     qwen2_moe_model_id = "RichardErkhov/Qwen_-_Qwen1.5-MoE-A2.7B-Chat-gguf"
     llama3_model_id = "NousResearch/Meta-Llama-3-8B-GGUF"
     tinyllama_model_id = "PenutChen/TinyLlama-1.1B-Chat-v1.0-GGUF"
+    phi3_model_id = "microsoft/Phi-3-mini-4k-instruct-gguf"
 
     # standard quants
     q4_0_gguf_model_id = "tinyllama-1.1b-chat-v1.0.Q4_0.gguf"
@@ -63,6 +64,7 @@
     iq4_xs_gguf_model_id = "TinyLlama-1.1B-Chat-v1.0-IQ4_XS.gguf"
     iq4_nl_gguf_model_id = "TinyLlama-1.1B-Chat-v1.0-IQ4_NL.gguf"
 
+    q4_0_phi3_model_id = "Phi-3-mini-4k-instruct-q4.gguf"
     q4_0_mistral_model_id = "mistral-7b-instruct-v0.2.Q4_0.gguf"
     q4_0_qwen2_model_id = "qwen1_5-0_5b-chat-q4_0.gguf"
     q4_0_qwen2_moe_model_id = "Qwen1.5-MoE-A2.7B-Chat.Q4_0.gguf"
@@ -347,6 +349,18 @@ def test_qwen2_moe_q4_0(self):
         EXPECTED_TEXT = "Hello everyone, I'm a newbie here and would like"
         self.assertEqual(tokenizer.decode(out[0], skip_special_tokens=True), EXPECTED_TEXT)
 
+    def test_phi3_q4_0(self):
+        tokenizer = AutoTokenizer.from_pretrained(self.phi3_model_id, gguf_file=self.q4_0_phi3_model_id)
+        model = AutoModelForCausalLM.from_pretrained(
+            self.phi3_model_id, gguf_file=self.q4_0_phi3_model_id, device_map="auto", torch_dtype=torch.float16
+        )
+
+        text = tokenizer(self.example_text, return_tensors="pt").to(torch_device)
+        out = model.generate(**text, max_new_tokens=10)
+
+        EXPECTED_TEXT = "Hello, I've been reading about the impact of"
+        self.assertEqual(tokenizer.decode(out[0], skip_special_tokens=True), EXPECTED_TEXT)
+
     def test_llama3_q4_0_tokenizer(self):
         tokenizer = AutoTokenizer.from_pretrained(self.llama3_model_id, gguf_file=self.q4_llama3_model_id)
         with tempfile.TemporaryDirectory() as tmpdirname:
