Skip to content

Commit aa9effe

Browse files
author
z00573959
committed
fix bug in UT weight conversion
1 parent 5a76add commit aa9effe

File tree

1 file changed

+5
-4
lines changed

1 file changed

+5
-4
lines changed

tests/modeling_test_utils.py

Lines changed: 5 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -131,10 +131,6 @@ def get_pt2ms_mappings(m):
131131
mappings[f"{name}.running_mean"] = f"{name}.moving_mean", lambda x: x
132132
mappings[f"{name}.running_var"] = f"{name}.moving_variance", lambda x: x
133133
mappings[f"{name}.num_batches_tracked"] = None, lambda x: x
134-
elif isinstance(cell, (mint.nn.BatchNorm1d, mint.nn.BatchNorm2d, mint.nn.BatchNorm3d)):
135-
# TODO: for mint.nn, the dtype of each param is expected to be the same between torch and mindspore
136-
# this is a temporary fix, delete this branch in future.
137-
mappings[f"{name}.num_batches_tracked"] = f"{name}.num_batches_tracked", lambda x: x.to(ms.float32)
138134
return mappings
139135

140136

@@ -150,6 +146,11 @@ def convert_state_dict(m, state_dict_pt):
150146
state_dict_ms = {}
151147
for name_pt, data_pt in state_dict_pt.items():
152148
name_ms, data_mapping = mappings.get(name_pt, (name_pt, lambda x: x))
149+
# for torch back compatibility
150+
# for torch<2.0, the dtype of num_batches_tracked is int32; for torch>=2.0 it is int64,
151+
# mindspore.mint is aligned with torch>=2.0
152+
if 'num_batches_tracked' in name_pt and data_pt.dtype == torch.int32:
153+
data_pt = data_pt.to(torch.int64)
153154
data_ms = ms.Parameter(
154155
data_mapping(ms.Tensor.from_numpy(data_pt.float().numpy()).to(dtype_mappings[data_pt.dtype])), name=name_ms
155156
)

0 commit comments

Comments
 (0)