Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions docs/source/en/gguf.md
Original file line number Diff line number Diff line change
Expand Up @@ -82,6 +82,7 @@ For now the supported model architectures are the architectures that have been v
- Phi3
- Bloom
- Falcon
- StableLM

## Example usage

Expand Down
33 changes: 30 additions & 3 deletions src/transformers/integrations/ggml.py
Original file line number Diff line number Diff line change
Expand Up @@ -143,6 +143,21 @@
".output.": ".lm_head.",
"output_norm": "ln_f",
},
"stablelm": {
"token_embd": "model.embed_tokens",
"blk": "model.layers",
"ffn_up": "mlp.up_proj",
"ffn_down": "mlp.down_proj",
"ffn_gate": "mlp.gate_proj",
"ffn_norm": "post_attention_layernorm",
"attn_norm": "input_layernorm",
"attn_q": "self_attn.q_proj",
"attn_v": "self_attn.v_proj",
"attn_k": "self_attn.k_proj",
"attn_output": "self_attn.o_proj",
"output.weight": "lm_head.weight",
"output_norm": "model.norm",
},
}


Expand Down Expand Up @@ -238,6 +253,17 @@
"vocab_size": "vocab_size",
"attention.layer_norm_epsilon": "layer_norm_epsilon",
},
"stablelm": {
"context_length": "max_position_embeddings",
"block_count": "num_hidden_layers",
"feed_forward_length": "intermediate_size",
"embedding_length": "hidden_size",
"rope.dimension_count": None,
"attention.head_count": "num_attention_heads",
"attention.head_count_kv": "num_key_value_heads",
"attention.layer_norm_epsilon": "layer_norm_eps",
"vocab_size": "vocab_size",
},
}

GGUF_TOKENIZER_MAPPING = {
Expand Down Expand Up @@ -547,7 +573,7 @@ def converted(self) -> Tokenizer:
return tokenizer


class GGUFBloomConverter(GPT2Converter):
class GGUFGPTConverter(GPT2Converter):
def __init__(self, tokenizer_dict):
self.original_tokenizer = GGUFTokenizerSkeleton(tokenizer_dict)
self.additional_kwargs = {}
Expand All @@ -564,8 +590,9 @@ def converted(self) -> Tokenizer:
"qwen2": GGUFQwen2Converter,
"qwen2_moe": GGUFQwen2Converter,
"phi3": GGUFPhi3Converter,
"bloom": GGUFBloomConverter,
"falcon": GGUFBloomConverter,
"bloom": GGUFGPTConverter,
"falcon": GGUFGPTConverter,
"stablelm": GGUFGPTConverter,
}


Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -105,8 +105,8 @@ def __init__(
**kwargs,
):
super().__init__(
vocab_file,
merges_file,
vocab_file=vocab_file,
merges_file=merges_file,
tokenizer_file=tokenizer_file,
unk_token=unk_token,
bos_token=bos_token,
Expand Down
74 changes: 74 additions & 0 deletions tests/quantization/ggml/test_ggml.py
Original file line number Diff line number Diff line change
Expand Up @@ -47,6 +47,9 @@ class GgufIntegrationTests(unittest.TestCase):
falcon7b_model_id = "xaviviro/falcon-7b-quantized-gguf"
falcon40b_model_id = "maddes8cht/tiiuae-falcon-40b-gguf"
original_flacon7b_model_id = "tiiuae/falcon-7b"
stablelm_model_id = "afrideva/stablelm-3b-4e1t-GGUF"
stablelm2_model_id = "afrideva/stablelm-2-1_6b-GGUF"
original_stablelm2_model_id = "stabilityai/stablelm-2-1_6b"

# standard quants
q4_0_gguf_model_id = "tinyllama-1.1b-chat-v1.0.Q4_0.gguf"
Expand All @@ -58,6 +61,7 @@ class GgufIntegrationTests(unittest.TestCase):
q4_k_gguf_model_id = "tinyllama-1.1b-chat-v1.0.Q4_K_M.gguf"
q5_k_gguf_model_id = "tinyllama-1.1b-chat-v1.0.Q5_K_M.gguf"
q6_k_gguf_model_id = "tinyllama-1.1b-chat-v1.0.Q6_K.gguf"
q4_k_m_stablelm_model_id = "stablelm-3b-4e1t.q4_k_m.gguf"
# imatrix
iq1_m_gguf_model_id = "TinyLlama-1.1B-Chat-v1.0-IQ1_M.gguf"
iq1_s_gguf_model_id = "TinyLlama-1.1B-Chat-v1.0-IQ1_S.gguf"
Expand All @@ -75,6 +79,7 @@ class GgufIntegrationTests(unittest.TestCase):
q4_0_qwen2_moe_model_id = "Qwen1.5-MoE-A2.7B-Chat.Q4_0.gguf"
q4_llama3_model_id = "Meta-Llama-3-8B-Q4_K_M.gguf"
fp16_bloom_model_id = "bloom-560m.fp16.gguf"
fp16_stablelm2_model_id = "stablelm-2-1_6b.fp16.gguf"
q8_bloom_model_id = "bloom-560m.q8_0.gguf"
f16_tinyllama_model_id = "TinyLlama-1.1B-Chat-v1.0.FP16.gguf"
q2_k_falcon7b_model_id = "falcon-7b-q2_k.gguf"
Expand Down Expand Up @@ -503,6 +508,75 @@ def test_falcon7b_weights_conversion_fp16(self):
self.assertTrue(original_params.shape == quantized_state_dict[layer_name].shape)
torch.testing.assert_close(original_params, quantized_state_dict[layer_name])

def test_stablelm_q4_k_m(self):
model = AutoModelForCausalLM.from_pretrained(
self.stablelm_model_id,
gguf_file=self.q4_k_m_stablelm_model_id,
device_map="auto",
torch_dtype=torch.float16,
)

tokenizer = AutoTokenizer.from_pretrained(self.stablelm_model_id, gguf_file=self.q4_k_m_stablelm_model_id)
text = tokenizer(self.example_text, return_tensors="pt").to(torch_device)
out = model.generate(**text, max_new_tokens=10)

EXPECTED_TEXT = "Hello-\nI am trying to create a new user"
self.assertEqual(tokenizer.decode(out[0], skip_special_tokens=True), EXPECTED_TEXT)

def test_stablelm_fp16(self):
original_model = AutoModelForCausalLM.from_pretrained(
self.original_stablelm2_model_id,
torch_dtype=torch.float16,
)

converted_model = AutoModelForCausalLM.from_pretrained(
self.stablelm2_model_id,
gguf_file=self.fp16_stablelm2_model_id,
torch_dtype=torch.float16,
# for precise comparison it is required to use the original model config
# as quantized one is different in parameters: use_parallel_residual and use_qkv_bias
# and it highly influences on the output results
config=original_model.config,
)

tokenizer = AutoTokenizer.from_pretrained(self.stablelm2_model_id, gguf_file=self.fp16_stablelm2_model_id)
text = tokenizer(self.example_text, return_tensors="pt")
original_out = original_model.generate(**text, max_new_tokens=10)
converted_out = converted_model.generate(**text, max_new_tokens=10)

EXPECTED_TEXT = "Hello, I am a 20 year old male"
self.assertEqual(tokenizer.decode(converted_out[0], skip_special_tokens=True), EXPECTED_TEXT)
self.assertEqual(
tokenizer.decode(converted_out[0], skip_special_tokens=True),
tokenizer.decode(original_out[0], skip_special_tokens=True),
)

def test_stablelm_weights_conversion_fp16(self):
original_model = AutoModelForCausalLM.from_pretrained(
self.original_stablelm2_model_id,
device_map="auto",
torch_dtype=torch.float16,
)

converted_model = AutoModelForCausalLM.from_pretrained(
self.stablelm2_model_id,
gguf_file=self.fp16_stablelm2_model_id,
device_map="auto",
torch_dtype=torch.float16,
# for precise comparison it is required to use the original model config
# as quantized one is different in parameters: use_parallel_residual and use_qkv_bias
# and it highly influences on the output results
config=original_model.config,
)

converted_state_dict = converted_model.state_dict()
original_state_dict = original_model.state_dict()

for layer_name, original_params in original_state_dict.items():
if layer_name in converted_state_dict:
self.assertTrue(original_params.shape == converted_state_dict[layer_name].shape)
torch.testing.assert_close(original_params, converted_state_dict[layer_name])

def test_tokenization_xnli(self):
import tqdm
from datasets import load_dataset
Expand Down
Loading