@@ -131,10 +131,6 @@ def get_pt2ms_mappings(m):
             mappings[f"{name}.running_mean"] = f"{name}.moving_mean", lambda x: x
             mappings[f"{name}.running_var"] = f"{name}.moving_variance", lambda x: x
             mappings[f"{name}.num_batches_tracked"] = None, lambda x: x
-        elif isinstance(cell, (mint.nn.BatchNorm1d, mint.nn.BatchNorm2d, mint.nn.BatchNorm3d)):
-            # TODO: for mint.nn, the dtype for each param is expected to be the same between torch and mindspore;
-            # this is a temporary fix, delete this branch in the future.
-            mappings[f"{name}.num_batches_tracked"] = f"{name}.num_batches_tracked", lambda x: x.to(ms.float32)
     return mappings
@@ -150,6 +146,11 @@ def convert_state_dict(m, state_dict_pt):
     state_dict_ms = {}
     for name_pt, data_pt in state_dict_pt.items():
         name_ms, data_mapping = mappings.get(name_pt, (name_pt, lambda x: x))
+        # for torch backward compatibility:
+        # for torch<2.0 the dtype of num_batches_tracked is int32, for torch>=2.0 it is int64;
+        # mindspore.mint is aligned with torch>=2.0
+        if 'num_batches_tracked' in name_pt and data_pt.dtype == torch.int32:
+            data_pt = data_pt.to(torch.int64)
         data_ms = ms.Parameter(
             data_mapping(ms.Tensor.from_numpy(data_pt.float().numpy()).to(dtype_mappings[data_pt.dtype])), name=name_ms
         )
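
For context, the added branch normalizes `num_batches_tracked` buffers saved by torch<2.0 (stored as int32) to the int64 layout that torch>=2.0 and `mindspore.mint` expect, before the tensor is handed to `ms.Tensor.from_numpy`. Below is a minimal standalone sketch of that normalization step; the helper name and the toy state dict are illustrative only and are not part of this patch or the repository's API.

```python
import torch

def normalize_num_batches_tracked(state_dict_pt):
    """Upcast int32 num_batches_tracked buffers (torch<2.0 checkpoints) to int64,
    matching what torch>=2.0 and mindspore.mint expect (hypothetical helper)."""
    for name, data in state_dict_pt.items():
        if "num_batches_tracked" in name and data.dtype == torch.int32:
            state_dict_pt[name] = data.to(torch.int64)
    return state_dict_pt

# toy checkpoint resembling one saved with torch<2.0
legacy_sd = {"bn.num_batches_tracked": torch.tensor(10, dtype=torch.int32)}
print(normalize_num_batches_tracked(legacy_sd)["bn.num_batches_tracked"].dtype)  # torch.int64
```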