Skip to content
Open
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
63 changes: 31 additions & 32 deletions tests/modeling_test_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -217,48 +217,47 @@ def set_dtype(model, dtype):
return model


def _generate_inputs(base_input, dtype, framework):
if isinstance(base_input, np.ndarray):
if base_input.dtype in (np.float16, np.float32, np.float64, bfloat16):
base_input = base_input.astype(dtype)

if framework == "torch":
return (
torch.from_numpy(base_input.astype(np.float32)).to(torch.bfloat16)
if dtype == bfloat16
else torch.from_numpy(base_input)
)
elif framework == "mindspore":
return ms.Tensor.from_numpy(base_input)
else:
raise ValueError(f"Unsupported framework: {framework}")

elif isinstance(base_input, (tuple, list)):
sequence_cls = type(base_input)
return sequence_cls(_generate_inputs(x, dtype, framework) for x in base_input)

elif isinstance(base_input, dict):
return {k: _generate_inputs(v, dtype, framework) for k, v in base_input.items()}

else:
return base_input


def generalized_parse_args(pt_dtype, ms_dtype, *args, **kwargs):
# parse args
pt_inputs_args = tuple()
ms_inputs_args = tuple()
for x in args:
if isinstance(x, np.ndarray):
if x.dtype in (np.float16, np.float32, np.float64, bfloat16):
px = x.astype(NP_DTYPE_MAPPING[pt_dtype])
mx = x.astype(NP_DTYPE_MAPPING[ms_dtype])
else:
px = mx = x

pt_inputs_args += (
(torch.from_numpy(px.astype(np.float32)).to(torch.bfloat16),)
if pt_dtype == "bf16"
else (torch.from_numpy(px),)
)
ms_inputs_args += (ms.Tensor.from_numpy(mx),)
else:
pt_inputs_args += (x,)
ms_inputs_args += (x,)
pt_inputs_args += (_generate_inputs(x, NP_DTYPE_MAPPING[pt_dtype], "torch"),)
ms_inputs_args += (_generate_inputs(x, NP_DTYPE_MAPPING[ms_dtype], "mindspore"),)

# parse kwargs
pt_inputs_kwargs = dict()
ms_inputs_kwargs = dict()
for k, v in kwargs.items():
if isinstance(v, np.ndarray):
if v.dtype in (np.float16, np.float32, np.float64, bfloat16):
px = v.astype(NP_DTYPE_MAPPING[pt_dtype])
mx = v.astype(NP_DTYPE_MAPPING[ms_dtype])
else:
px = mx = v

pt_inputs_kwargs[k] = (
torch.from_numpy(px.astype(np.float32)).to(torch.bfloat16)
if pt_dtype == "bf16"
else torch.from_numpy(px)
)
ms_inputs_kwargs[k] = ms.Tensor.from_numpy(mx)
else:
pt_inputs_kwargs[k] = v
ms_inputs_kwargs[k] = v
pt_inputs_kwargs[k] = _generate_inputs(v, NP_DTYPE_MAPPING[pt_dtype], "torch")
ms_inputs_kwargs[k] = _generate_inputs(v, NP_DTYPE_MAPPING[ms_dtype], "mindspore")

return pt_inputs_args, pt_inputs_kwargs, ms_inputs_args, ms_inputs_kwargs

Expand Down