
Commit dd40af2

Author: WongGawa (committed)
feat: add zamba2 model pipeline
Parent: 1a41958

File tree

7 files changed: +1859 -0 lines changed

mindone/transformers/__init__.py

Lines changed: 1 addition & 0 deletions
@@ -265,5 +265,6 @@
     WhisperProcessor,
 )
 from .models.xlm_roberta import XLMRobertaModel, XLMRobertaPreTrainedModel
+from .models.zamba2 import Zamba2ForCausalLM, Zamba2ForSequenceClassification, Zamba2Model, Zamba2PreTrainedModel
 from .pipelines import TextGenerationPipeline, pipeline
 from .processing_utils import ProcessorMixin
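With this export in place, the Zamba2 classes are importable directly from mindone.transformers. Below is a minimal forward-pass sketch (not part of the commit): it mirrors the tiny hyperparameters used by the test module further down, assumes MindSpore PyNative mode, and reads the last hidden state at index 0, matching the outputs_map used by that test.

import numpy as np
import mindspore as ms
from transformers import Zamba2Config
from mindone.transformers import Zamba2Model

ms.set_context(mode=ms.PYNATIVE_MODE)  # matches MODES = [1] in the test below

# Tiny, randomly initialized config mirroring Zamba2ModelTester.get_config();
# values are illustrative, not a real checkpoint.
config = Zamba2Config(
    attn_implementation="eager",
    vocab_size=99,
    hidden_size=32,
    num_hidden_layers=54,
    num_attention_heads=4,
    num_key_value_heads=2,
    intermediate_size=37,
    pad_token_id=0,
)
model = Zamba2Model(config)

# Random token ids with shape (batch, seq_len)
input_ids = ms.Tensor(np.random.randint(0, config.vocab_size, (1, 7)), ms.int32)
outputs = model(input_ids)
last_hidden_state = outputs[0]  # (1, 7, hidden_size)
print(last_hidden_state.shape)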

mindone/transformers/models/auto/configuration_auto.py

Lines changed: 2 additions & 0 deletions
@@ -71,6 +71,7 @@
         ("wav2vec2", "Wav2Vec2Config"),
         ("whisper", "WhisperConfig"),
         ("xlm-roberta", "XLMRobertaConfig"),
+        ("zamba2", "Zamba2Config"),
     ]
 )

@@ -126,6 +127,7 @@
         ("whisper", "Whisper"),
         ("xlm-roberta", "XLM-RoBERTa"),
         ("xlm-roberta-xl", "XLM-RoBERTa-XL"),
+        ("zamba2", "Zamba2"),
     ]
 )

mindone/transformers/models/auto/modeling_auto.py

Lines changed: 3 additions & 0 deletions
@@ -69,6 +69,7 @@
         ("wav2vec2", "Wav2Vec2Model"),
         ("whisper", "WhisperModel"),
         ("xlm-roberta", "XLMRobertaModel"),
+        ("zamba2", "Zamba2Model"),
     ]
 )

@@ -142,6 +143,7 @@
         ("ijepa", "IJepaModel"),
         ("imagegpt", "ImageGPTModel"),
         ("levit", "LevitModel"),
+        ("zamba2", "Zamba2ForCausalLM"),
     ]
 )

@@ -286,6 +288,7 @@
         ("qwen3", "Qwen3ForSequenceClassification"),
         ("t5", "T5ForSequenceClassification"),
         ("umt5", "UMT5ForSequenceClassification"),
+        ("zamba2", "Zamba2ForSequenceClassification"),
     ]
 )

mindone/transformers/models/zamba2/__init__.py

Lines changed: 1 addition & 0 deletions
@@ -0,0 +1 @@
from .modeling_zamba2 import Zamba2ForCausalLM, Zamba2ForSequenceClassification, Zamba2Model, Zamba2PreTrainedModel

mindone/transformers/models/zamba2/modeling_zamba2.py

Lines changed: 1617 additions & 0 deletions
Large diffs are not rendered by default.

tests/transformers_tests/models/zamba2/__init__.py

Whitespace-only changes.

tests/transformers_tests/models/zamba2/test_modeling_zamba2.py

Lines changed: 235 additions & 0 deletions
@@ -0,0 +1,235 @@
# This module contains test cases that are defined in the `.test_cases.py` file, structured as lists or tuples like
# [name, pt_module, ms_module, init_args, init_kwargs, inputs_args, inputs_kwargs, outputs_map].
#
# Each defined case corresponds to a pair consisting of PyTorch and MindSpore modules, including their respective
# initialization parameters and inputs for the forward. The testing framework adopted here is designed to generically
# parse these parameters to assess and compare the precision of forward outcomes between the two frameworks.
#
# In cases where models have unique initialization procedures or require testing with specialized output formats,
# it is necessary to develop distinct, dedicated test cases.

import inspect

import numpy as np
import pytest
import torch
from transformers import Zamba2Config

import mindspore as ms

from tests.modeling_test_utils import (
    MS_DTYPE_MAPPING,
    PT_DTYPE_MAPPING,
    compute_diffs,
    generalized_parse_args,
    get_modules,
)
from tests.transformers_tests.models.modeling_common import ids_numpy

DTYPE_AND_THRESHOLDS = {"fp32": 5e-4, "fp16": 5e-3, "bf16": 5e-2}
MODES = [1]


class Zamba2ModelTester:
    config_class = Zamba2Config

    def __init__(
        self,
        batch_size=13,
        seq_length=7,
        is_training=True,
        use_input_mask=True,
        use_token_type_ids=False,
        use_labels=True,
        vocab_size=99,
        hidden_size=32,
        num_hidden_layers=54,
        num_attention_heads=4,
        num_key_value_heads=2,
        intermediate_size=37,
        hidden_act="gelu",
        hidden_dropout_prob=0.1,
        attention_probs_dropout_prob=0.1,
        max_position_embeddings=512,
        type_vocab_size=16,
        type_sequence_label_size=2,
        initializer_range=0.02,
        num_labels=3,
        num_choices=4,
        pad_token_id=0,
        scope=None,
    ):
        self.batch_size = batch_size
        self.seq_length = seq_length
        self.is_training = is_training
        self.use_input_mask = use_input_mask
        self.use_token_type_ids = use_token_type_ids
        self.use_labels = use_labels
        self.vocab_size = vocab_size
        self.hidden_size = hidden_size
        self.num_hidden_layers = num_hidden_layers
        self.num_attention_heads = num_attention_heads
        self.num_key_value_heads = num_key_value_heads
        self.intermediate_size = intermediate_size
        self.hidden_act = hidden_act
        self.hidden_dropout_prob = hidden_dropout_prob
        self.attention_probs_dropout_prob = attention_probs_dropout_prob
        self.max_position_embeddings = max_position_embeddings
        self.type_vocab_size = type_vocab_size
        self.type_sequence_label_size = type_sequence_label_size
        self.initializer_range = initializer_range
        self.num_labels = num_labels
        self.num_choices = num_choices
        self.pad_token_id = pad_token_id
        self.scope = scope
        self.head_dim = self.hidden_size // self.num_attention_heads

    # Copied from tests.models.mistral.test_modeling_mistral.MistralModelTester.prepare_config_and_inputs
    def prepare_config_and_inputs(self):
        input_ids = ids_numpy([self.batch_size, self.seq_length], self.vocab_size)

        input_mask = None
        if self.use_input_mask:
            input_mask = np.tril(np.ones_like(input_ids))

        token_type_ids = None
        if self.use_token_type_ids:
            token_type_ids = ids_numpy([self.batch_size, self.seq_length], self.type_vocab_size)

        sequence_labels = None
        token_labels = None
        choice_labels = None
        if self.use_labels:
            sequence_labels = ids_numpy([self.batch_size], self.type_sequence_label_size)
            token_labels = ids_numpy([self.batch_size, self.seq_length], self.num_labels)
            choice_labels = ids_numpy([self.batch_size], self.num_choices)

        config = self.get_config()

        return config, input_ids, token_type_ids, input_mask, sequence_labels, token_labels, choice_labels

    def get_config(self):
        return self.config_class(
            attn_implementation="eager",
            vocab_size=self.vocab_size,
            hidden_size=self.hidden_size,
            num_hidden_layers=self.num_hidden_layers,
            num_attention_heads=self.num_attention_heads,
            num_key_value_heads=self.num_key_value_heads,
            intermediate_size=self.intermediate_size,
            hidden_act=self.hidden_act,
            hidden_dropout_prob=self.hidden_dropout_prob,
            attention_probs_dropout_prob=self.attention_probs_dropout_prob,
            max_position_embeddings=self.max_position_embeddings,
            type_vocab_size=self.type_vocab_size,
            is_decoder=False,
            initializer_range=self.initializer_range,
            pad_token_id=self.pad_token_id,
            head_dim=self.head_dim,
        )


model_tester = Zamba2ModelTester()
(
    config,
    input_ids,
    token_type_ids,
    input_mask,
    sequence_labels,
    token_labels,
    choice_labels,
) = model_tester.prepare_config_and_inputs()


Zamba2_CASES = [
    [
        "Zamba2Model",
        "transformers.Zamba2Model",
        "mindone.transformers.Zamba2Model",
        (config,),
        {},
        (input_ids,),
        {
            "attention_mask": input_mask,
        },
        {
            "last_hidden_state": 0,
        },
    ],
]


# transformers need >= 4.41.2
@pytest.mark.parametrize(
    "name,pt_module,ms_module,init_args,init_kwargs,inputs_args,inputs_kwargs,outputs_map,dtype,mode",
    [
        case
        + [
            dtype,
        ]
        + [
            mode,
        ]
        for case in Zamba2_CASES
        for dtype in DTYPE_AND_THRESHOLDS.keys()
        for mode in MODES
    ],
)
def test_named_modules(
    name,
    pt_module,
    ms_module,
    init_args,
    init_kwargs,
    inputs_args,
    inputs_kwargs,
    outputs_map,
    dtype,
    mode,
):
    ms.set_context(mode=mode)

    (
        pt_model,
        ms_model,
        pt_dtype,
        ms_dtype,
    ) = get_modules(pt_module, ms_module, dtype, *init_args, **init_kwargs)
    pt_inputs_args, pt_inputs_kwargs, ms_inputs_args, ms_inputs_kwargs = generalized_parse_args(
        pt_dtype, ms_dtype, *inputs_args, **inputs_kwargs
    )

    # set `hidden_dtype` if requiring, for some modules always compute in float
    # precision and require specific `hidden_dtype` to cast before return
    if "hidden_dtype" in inspect.signature(pt_model.forward).parameters:
        pt_inputs_kwargs.update({"hidden_dtype": PT_DTYPE_MAPPING[pt_dtype]})
        ms_inputs_kwargs.update({"hidden_dtype": MS_DTYPE_MAPPING[ms_dtype]})
    if mode == 0:
        ms_inputs_kwargs.update({"use_cache": False})
    with torch.no_grad():
        pt_outputs = pt_model(*pt_inputs_args, **pt_inputs_kwargs)
        ms_outputs = ms_model(*ms_inputs_args, **ms_inputs_kwargs)
    # print("ms:", ms_outputs)
    # print("pt:", pt_outputs)
    if outputs_map:
        pt_outputs_n = []
        ms_outputs_n = []
        for pt_key, ms_idx in outputs_map.items():
            # print("===map", pt_key, ms_idx)
            pt_output = getattr(pt_outputs, pt_key)
            ms_output = ms_outputs[ms_idx]
            if isinstance(pt_output, (list, tuple)):
                pt_outputs_n += list(pt_output)
                ms_outputs_n += list(ms_output)
            else:
                pt_outputs_n.append(pt_output)
                ms_outputs_n.append(ms_output)
        diffs = compute_diffs(pt_outputs_n, ms_outputs_n)
    else:
        diffs = compute_diffs(pt_outputs, ms_outputs)

    THRESHOLD = DTYPE_AND_THRESHOLDS[ms_dtype]
    assert (np.array(diffs) < THRESHOLD).all(), (
        f"ms_dtype: {ms_dtype}, pt_type:{pt_dtype}, "
        f"Outputs({np.array(diffs).tolist()}) has diff bigger than {THRESHOLD}"
    )
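To run only these parity cases, one option (an illustrative invocation, assuming a source checkout with torch, transformers >= 4.41.2, and mindspore installed) is to point pytest at the test directory added by this commit:

# Collects and runs the zamba2 test module above from the repository root.
import pytest

pytest.main(["-q", "tests/transformers_tests/models/zamba2"])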
