diff --git a/python/tvm/relay/frontend/pytorch.py b/python/tvm/relay/frontend/pytorch.py index 5e4d75599613..09e750bc1a0c 100644 --- a/python/tvm/relay/frontend/pytorch.py +++ b/python/tvm/relay/frontend/pytorch.py @@ -1622,6 +1622,18 @@ def view(self, inputs, input_types): return _op.transform.reshape(data, new_shape) + def view_as(self, inputs, input_types): + data = inputs[0] + tensors = inputs[1] + + if not isinstance(tensors, (_expr.Call, _expr.Constant, _expr.Var)): + msg = f"Data type {type(tensors)} could not be parsed in view_as op" + raise AssertionError(msg) + + shape = self.infer_shape(tensors) + + return _op.transform.reshape(data, shape) + def reshape(self, inputs, input_types): data = inputs[0] new_shape = inputs[1] @@ -3838,6 +3850,7 @@ def create_convert_map(self): "aten::addmm": self.addmm, "aten::size": self.size, "aten::view": self.view, + "aten::view_as": self.view_as, "aten::reshape": self.reshape, "aten::reshape_as": self.reshape_as, "aten::clone": self.clone, diff --git a/tests/python/frontend/pytorch/test_forward.py b/tests/python/frontend/pytorch/test_forward.py index 83930d1ea80b..75510b2608ad 100644 --- a/tests/python/frontend/pytorch/test_forward.py +++ b/tests/python/frontend/pytorch/test_forward.py @@ -1660,6 +1660,21 @@ def forward(self, *args): verify_model(View3().float().eval(), input_data=input_data) +@tvm.testing.uses_gpu +def test_forward_view_as(): + """test_forward_view_as""" + torch.set_grad_enabled(False) + input_shape = [1, 3, 10] + + class ViewAs1(Module): + def forward(self, *args): + t1 = torch.ones((1 * 3 * 10)) + return args[0].view_as(t1) + + input_data = torch.rand(input_shape).float() + verify_model(ViewAs1().float().eval(), input_data=input_data) + + @tvm.testing.uses_gpu def test_forward_select(): """test_forward_select"""