diff --git a/test/dtypes/test_nf4.py b/test/dtypes/test_nf4.py
index 14db986297..6649c1dd54 100644
--- a/test/dtypes/test_nf4.py
+++ b/test/dtypes/test_nf4.py
@@ -20,6 +20,7 @@
     run_tests,
 )
 from torchao.dtypes.nf4tensor import (
+    NF4Tensor,
     linear_nf4,
     to_nf4,
     _INNER_TENSOR_NAMES_FOR_SHARDING,
@@ -270,6 +271,15 @@ def test_chunk_size_equivalence(self, dtype: torch.dtype, shape, chunk_size):
         torch.testing.assert_close(nf4_patched.quantized_data, nf4_base.quantized_data)
 
+    @unittest.skipIf(not torch.cuda.is_available(), "Need CUDA available")
+    @parametrize("input_size", [(512 * 512,), (512, 512)])
+    def test_empty_like(self, input_size: Union[Tuple[int], int]):
+        nf4_tensor = to_nf4(torch.rand(input_size, device="cuda"))
+        new_tensor = torch.empty_like(nf4_tensor, device="cpu")
+        self.assertTrue(isinstance(new_tensor, NF4Tensor))
+        self.assertEqual(new_tensor.get_device(), -1)  # that it's on CPU
+        self.assertEqual(new_tensor.size(), nf4_tensor.size())
+
 
 class TestFSDPOps(TestCase):
     @parametrize("input_size", [512 * 512, (512 * 512,), (512, 512)])
diff --git a/torchao/dtypes/nf4tensor.py b/torchao/dtypes/nf4tensor.py
index f46e43f91d..40424ab400 100644
--- a/torchao/dtypes/nf4tensor.py
+++ b/torchao/dtypes/nf4tensor.py
@@ -164,6 +164,20 @@ def nf4_detach(aten_op, args, kwargs=None):
     return NF4Tensor(*construct_nf4_args(nf4tensor, updated_attrs))
 
 
+@implements(
+    [
+        aten.empty_like.default,
+    ]
+)
+def nf4_empty_like(aten_op, args, kwargs=None):
+    nf4tensor = args[0]
+    updated_attrs = apply_to_inner_tensors(nf4tensor, aten_op, args[1:], kwargs)
+    if kwargs is not None and len(kwargs):
+        for key, value in kwargs.items():
+            updated_attrs[key] = value
+    return NF4Tensor(*construct_nf4_args(nf4tensor, updated_attrs))
+
+
 @implements(
     [
         aten.split.Tensor,
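
For reference, a minimal usage sketch of what this patch enables (an illustration under stated assumptions, not part of the patch itself): it assumes a CUDA device is available and uses NF4Tensor and to_nf4 from torchao.dtypes.nf4tensor as shown in the diff above.

    import torch
    from torchao.dtypes.nf4tensor import NF4Tensor, to_nf4

    # Quantize a CUDA tensor to NF4.
    nf4 = to_nf4(torch.rand(512, 512, device="cuda"))

    # torch.empty_like now dispatches to nf4_empty_like, which applies
    # aten.empty_like to each inner tensor and then folds any kwargs
    # (here, device) into the attributes used to rebuild the NF4Tensor.
    out = torch.empty_like(nf4, device="cpu")

    assert isinstance(out, NF4Tensor)
    assert out.get_device() == -1  # get_device() returns -1 for CPU tensors
    assert out.size() == nf4.size()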