Skip to content

Commit a584e24

Browse files
authored
Add empty_like for NF4Tensor to support offloading (#881)
Driss, we can confirm semantics when you're back but I'm fairly confident this is okay
1 parent 8aa6533 commit a584e24

File tree

2 files changed

+24
-0
lines changed

2 files changed

+24
-0
lines changed

test/dtypes/test_nf4.py

Lines changed: 10 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -20,6 +20,7 @@
2020
run_tests,
2121
)
2222
from torchao.dtypes.nf4tensor import (
23+
NF4Tensor,
2324
linear_nf4,
2425
to_nf4,
2526
_INNER_TENSOR_NAMES_FOR_SHARDING,
@@ -270,6 +271,15 @@ def test_chunk_size_equivalence(self, dtype: torch.dtype, shape, chunk_size):
270271

271272
torch.testing.assert_close(nf4_patched.quantized_data, nf4_base.quantized_data)
272273

274+
@unittest.skipIf(not torch.cuda.is_available(), "Need CUDA available")
275+
@parametrize("input_size", [(512 * 512,), (512, 512)])
276+
def test_empty_like(self, input_size: Union[Tuple[int], int]):
277+
nf4_tensor = to_nf4(torch.rand(input_size, device="cuda"))
278+
new_tensor = torch.empty_like(nf4_tensor, device="cpu")
279+
self.assertTrue(isinstance(new_tensor, NF4Tensor))
280+
self.assertEqual(new_tensor.get_device(), -1) # that it's on CPU
281+
self.assertEqual(new_tensor.size(), nf4_tensor.size())
282+
273283

274284
class TestFSDPOps(TestCase):
275285
@parametrize("input_size", [512 * 512, (512 * 512,), (512, 512)])

torchao/dtypes/nf4tensor.py

Lines changed: 14 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -164,6 +164,20 @@ def nf4_detach(aten_op, args, kwargs=None):
164164
return NF4Tensor(*construct_nf4_args(nf4tensor, updated_attrs))
165165

166166

167+
@implements(
168+
[
169+
aten.empty_like.default,
170+
]
171+
)
172+
def nf4_empty_like(aten_op, args, kwargs=None):
173+
nf4tensor = args[0]
174+
updated_attrs = apply_to_inner_tensors(nf4tensor, aten_op, args[1:], kwargs)
175+
if kwargs is not None and len(kwargs):
176+
for key, value in kwargs.items():
177+
updated_attrs[key] = value
178+
return NF4Tensor(*construct_nf4_args(nf4tensor, updated_attrs))
179+
180+
167181
@implements(
168182
[
169183
aten.split.Tensor,

0 commit comments

Comments
 (0)