diff --git a/test/dtypes/test_nf4.py b/test/dtypes/test_nf4.py
index 37d6f3d5da..9f1c373307 100644
--- a/test/dtypes/test_nf4.py
+++ b/test/dtypes/test_nf4.py
@@ -30,6 +30,7 @@
     _INNER_TENSOR_NAMES_FOR_SHARDING,
     NF4Tensor,
     linear_nf4,
+    nf4_weight_only,
     to_nf4,
 )
 
@@ -281,6 +282,32 @@ def test_empty_like(self, input_size: Union[Tuple[int], int]):
         self.assertEqual(new_tensor.get_device(), -1)  # that it's on CPU
         self.assertEqual(new_tensor.size(), nf4_tensor.size())
 
+    @unittest.skipIf(not torch.cuda.is_available(), "Need CUDA available")
+    @parametrize("compile", [False, True])
+    def test_quantize_api(self, compile):
+        nf4_linear = nn.Linear(512, 512, device="cuda")
+        torchao.quantize_(nf4_linear, nf4_weight_only())
+        assert isinstance(nf4_linear.weight, NF4Tensor)
+
+        ref_linear = copy.deepcopy(nf4_linear)
+        ref_linear.weight.data = ref_linear.weight.get_original_weight()  # dequantize
+
+        if compile:
+            nf4_linear.compile()
+            ref_linear.compile()
+
+        nf4_x = torch.randn(2, 512, device="cuda").requires_grad_()
+        ref_x = nf4_x.detach().clone().requires_grad_()
+
+        nf4_out = nf4_linear(nf4_x)
+        ref_out = ref_linear(ref_x)
+        self.assertEqual(nf4_out, ref_out)
+
+        grad_out = torch.randn(2, 512, device="cuda")
+        nf4_out.backward(grad_out)
+        ref_out.backward(grad_out)
+        self.assertEqual(nf4_x.grad, ref_x.grad)
+
 
 class TestFSDPOps(TestCase):
     @parametrize("input_size", [512 * 512, (512 * 512,), (512, 512)])
diff --git a/torchao/dtypes/nf4tensor.py b/torchao/dtypes/nf4tensor.py
index 1dc74c9916..617fd0871b 100644
--- a/torchao/dtypes/nf4tensor.py
+++ b/torchao/dtypes/nf4tensor.py
@@ -954,6 +954,15 @@ def to_nf4(tensor, block_size: int = 64, scaler_block_size: int = 256):
     return NF4Tensor.from_tensor(tensor, block_size, scaler_block_size)
 
 
+def nf4_weight_only(block_size: int = 64, scaler_block_size: int = 256):
+    from torchao.quantization.quant_api import _get_linear_subclass_inserter
+
+    def _to_nf4(tensor: torch.Tensor):
+        return NF4Tensor.from_tensor(tensor, block_size, scaler_block_size)
+
+    return _get_linear_subclass_inserter(_to_nf4)
+
+
 NF4_TORCH_FUNCTIONS = {}
 
 
@@ -1000,6 +1009,17 @@ def function_cpu(*args, **kwargs):
     return NF4Tensor(*construct_nf4_args(nf4tensor, updated_attrs))
 
 
+@implements_torch_function(F.linear)
+def _(*args, **kwargs):
+    input = args[0]
+    weight = args[1]
+    bias = args[2] if len(args) > 2 else None
+    out = LinearNF4.apply(input, weight)
+    if bias is not None:
+        out = out + bias
+    return out
+
+
 @torch._dynamo.allow_in_graph
 def nf4_constructor(
     tensor_meta: SubclassTensorArgs,
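
Usage note (not part of the patch): a minimal sketch of how the new nf4_weight_only config is exercised end to end, mirroring test_quantize_api above. torchao.quantize_, nf4_weight_only, NF4Tensor, and the F.linear override come from this diff; the single-layer model, shapes, and the rest of the scaffolding are illustrative assumptions.

    import torch
    import torch.nn as nn

    import torchao
    from torchao.dtypes.nf4tensor import NF4Tensor, nf4_weight_only

    # Illustrative model; quantize_ swaps each nn.Linear weight for an NF4Tensor in place.
    model = nn.Sequential(nn.Linear(512, 512, device="cuda"))
    torchao.quantize_(model, nf4_weight_only(block_size=64, scaler_block_size=256))
    assert isinstance(model[0].weight, NF4Tensor)

    # Forward and backward now dispatch through the F.linear override added above (LinearNF4).
    x = torch.randn(2, 512, device="cuda", requires_grad=True)
    model(x).sum().backward()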