coremltools/converters/mil/frontend/torch/ops.py (4 additions, 0 deletions)
@@ -27,6 +27,7 @@
 )
 from coremltools.converters.mil.mil.types import is_bool, nptype_from_builtin
 from coremltools.converters.mil.mil.types.symbolic import any_symbolic, is_symbolic
+from coremltools.converters.mil.mil.types.type_mapping import builtin_to_string
 from coremltools.converters.mil.mil.var import ListVar, Var
 
 from .._utils import build_einsum_mil, value_at
@@ -4241,6 +4242,9 @@ def masked_fill(context, node):
     # cond must be bool type
     mask = mb.cast(x=mask, dtype="bool")
 
+    if value.dtype != x.dtype:
+        value = mb.cast(x=value, dtype=builtin_to_string(x.dtype))
+
     res = mb.select(cond=mask, a=value, b=x, name=node.name)
     context.add(res)
 
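For context, here is a minimal conversion sketch (not part of the PR; the module name and shapes are illustrative) of the dtype mismatch the new cast handles: a Python scalar fill value lowers to an int32 const while the input is fp32, so masked_fill now casts the value to the input's dtype before mb.select.

```python
# Hypothetical repro: float32 input, integer fill value. The traced Python int
# becomes an int32 const, which the updated masked_fill casts to fp32 so that
# mb.select sees matching operand dtypes.
import numpy as np
import torch
import coremltools as ct


class MaskedFillModel(torch.nn.Module):
    def forward(self, x):
        mask = torch.tensor([True, False, True])
        return torch.masked_fill(x, mask, 7)


traced = torch.jit.trace(MaskedFillModel().eval(), torch.rand(2, 3))
mlmodel = ct.convert(
    traced,
    inputs=[ct.TensorType(shape=(2, 3), dtype=np.float32)],
)
```

The test update below also covers the converse pairing (an int32 input with a float fill value such as 10.3), which takes the same cast path.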
coremltools/converters/mil/frontend/torch/test/test_torch_ops.py (15 additions, 6 deletions)
@@ -7927,21 +7927,30 @@ def test_constant_pad_3d(self, compute_unit, backend):
 
 class TestMaskedFill(TorchBaseTest):
     @pytest.mark.parametrize(
-        "compute_unit, backend",
-        itertools.product(compute_units, backends),
+        "compute_unit, backend, dtype, value",
+        itertools.product(
+            compute_units,
+            backends,
+            [np.int32, np.float32],
+            [10.3, 7, 0],
+        ),
     )
-    def test_masked_fill(self, compute_unit, backend):
+    def test_masked_fill(self, compute_unit, backend, dtype, value):
         SHAPE = (2, 3)
         MASK = torch.bernoulli(torch.rand(SHAPE[-1])).to(torch.bool)
-        VALUE = 10.0
 
-        model = ModuleWrapper(torch.masked_fill, {"mask": MASK, "value": VALUE})
+        input_data = np.random.randint(-100, 100, SHAPE).astype(dtype)
+        input_data = torch.from_numpy(input_data)
+        model = ModuleWrapper(torch.masked_fill, {"mask": MASK, "value": value})
+        converter_input_type = [TensorType(shape=SHAPE, dtype=dtype)]
 
         TorchBaseTest.run_compare_torch(
-            SHAPE,
+            input_data,
             model,
             backend=backend,
             compute_unit=compute_unit,
+            input_as_shape=False,
+            converter_input_type=converter_input_type,
         )


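For reference, a small illustrative sketch of the dtype/value grid the new parametrization sweeps for each (compute_unit, backend) pair; the mixed pairings, such as an int32 input with the 10.3 fill value, are the ones that exercise the cast added in ops.py.

```python
# Enumerate the six added dtype/value combinations per compute_unit/backend pair.
import itertools

import numpy as np

for dtype, value in itertools.product([np.int32, np.float32], [10.3, 7, 0]):
    print(np.dtype(dtype).name, value)
```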