Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions docs/source/en/gguf.md
Original file line number Diff line number Diff line change
Expand Up @@ -84,6 +84,7 @@ For now the supported model architectures are the architectures that have been v
- Falcon
- StableLM
- GPT2
- Starcoder2

## Example usage

Expand Down
24 changes: 24 additions & 0 deletions src/transformers/integrations/ggml.py
Original file line number Diff line number Diff line change
Expand Up @@ -176,6 +176,20 @@
"ffn_up": "mlp.c_fc",
"ffn_down": "mlp.c_proj",
},
"starcoder2": {
"token_embd": "model.embed_tokens",
"blk": "model.layers",
"ffn_up": "mlp.c_fc",
"ffn_down": "mlp.c_proj",
"ffn_norm": "post_attention_layernorm",
"attn_norm": "input_layernorm",
"attn_q": "self_attn.q_proj",
"attn_v": "self_attn.v_proj",
"attn_k": "self_attn.k_proj",
"attn_output": "self_attn.o_proj",
"output.weight": "lm_head.weight",
"output_norm": "model.norm",
},
}


Expand Down Expand Up @@ -292,6 +306,15 @@
"attention.head_count": "n_head",
"attention.layer_norm_epsilon": "layer_norm_epsilon",
},
"starcoder2": {
"block_count": "num_hidden_layers",
"context_length": "max_position_embeddings",
"embedding_length": "hidden_size",
"feed_forward_length": "intermediate_size",
"attention.head_count": "num_attention_heads",
"attention.head_count_kv": "num_key_value_heads",
"attention.layer_norm_epsilon": "norm_epsilon",
},
}

GGUF_TOKENIZER_MAPPING = {
Expand Down Expand Up @@ -622,6 +645,7 @@ def converted(self) -> Tokenizer:
"falcon": GGUFGPTConverter,
"stablelm": GGUFGPTConverter,
"gpt2": GGUFGPTConverter,
"starcoder2": GGUFGPTConverter,
}


Expand Down
44 changes: 44 additions & 0 deletions tests/quantization/ggml/test_ggml.py
Original file line number Diff line number Diff line change
Expand Up @@ -54,6 +54,9 @@ class GgufIntegrationTests(unittest.TestCase):
gpt2_model_id = "mradermacher/gpt2-GGUF"
gpt2_original_model_id = "openai-community/gpt2"
gpt2_xl_model_id = "RichardErkhov/openai-community_-_gpt2-xl-gguf"
starcoder2_model_id = "QuantFactory/starcoder2-3b-GGUF"
starcoder2_fp16_model_id = "brittlewis12/starcoder2-3b-GGUF"
starcoder2_original_model_id = "bigcode/starcoder2-3b"

# standard quants
q4_0_gguf_model_id = "tinyllama-1.1b-chat-v1.0.Q4_0.gguf"
Expand Down Expand Up @@ -93,6 +96,8 @@ class GgufIntegrationTests(unittest.TestCase):
fp16_gpt2_model_id = "gpt2.f16.gguf"
q8_gpt2_model_id = "gpt2.Q8_0.gguf"
q6_k_gpt2_xl_model_id = "gpt2-xl.Q6_K.gguf"
q6_k_starcoder2_model_id = "starcoder2-3b.Q6_K.gguf"
fp16_starcoder2_gguf_model_id = "starcoder2-3b.fp16.gguf"

example_text = "Hello"

Expand Down Expand Up @@ -650,6 +655,45 @@ def test_stablelm_weights_conversion_fp16(self):
self.assertTrue(original_params.shape == converted_state_dict[layer_name].shape)
torch.testing.assert_close(original_params, converted_state_dict[layer_name])

def test_starcoder2_weights_conversion_fp16(self):
original_model = AutoModelForCausalLM.from_pretrained(
self.starcoder2_original_model_id,
device_map="auto",
torch_dtype=torch.float16,
)

converted_model = AutoModelForCausalLM.from_pretrained(
self.starcoder2_fp16_model_id,
gguf_file=self.fp16_starcoder2_gguf_model_id,
device_map="auto",
torch_dtype=torch.float16,
)

converted_state_dict = converted_model.state_dict()
original_state_dict = original_model.state_dict()

for layer_name, original_params in original_state_dict.items():
if layer_name in converted_state_dict and layer_name != "lm_head.weight":
# quantized models do not contain "lm_head.weight" layer
self.assertTrue(original_params.shape == converted_state_dict[layer_name].shape)
torch.testing.assert_close(original_params, converted_state_dict[layer_name])

def test_starcoder2_q6_k(self):
example_function_text = "def print_hello_world():"
model = AutoModelForCausalLM.from_pretrained(
self.starcoder2_model_id,
gguf_file=self.q6_k_starcoder2_model_id,
device_map="auto",
torch_dtype=torch.float16,
)

tokenizer = AutoTokenizer.from_pretrained(self.starcoder2_model_id, gguf_file=self.q6_k_starcoder2_model_id)
text = tokenizer(example_function_text, return_tensors="pt").to(torch_device)
out = model.generate(**text, max_new_tokens=10)

EXPECTED_TEXT = 'def print_hello_world():\n print("Hello World")\n\ndef print'
self.assertEqual(tokenizer.decode(out[0], skip_special_tokens=True), EXPECTED_TEXT)

def test_tokenization_xnli(self):
import tqdm
from datasets import load_dataset
Expand Down