Skip to content

[NF4] Add quantize_() API support for NF4 #1216

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 1 commit into from
Nov 4, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
27 changes: 27 additions & 0 deletions test/dtypes/test_nf4.py
Original file line number Diff line number Diff line change
Expand Up @@ -30,6 +30,7 @@
_INNER_TENSOR_NAMES_FOR_SHARDING,
NF4Tensor,
linear_nf4,
nf4_weight_only,
to_nf4,
)

Expand Down Expand Up @@ -281,6 +282,32 @@ def test_empty_like(self, input_size: Union[Tuple[int], int]):
self.assertEqual(new_tensor.get_device(), -1) # that it's on CPU
self.assertEqual(new_tensor.size(), nf4_tensor.size())

@unittest.skipIf(not torch.cuda.is_available(), "Need CUDA available")
@parametrize("compile", [False, True])
def test_quantize_api(self, compile):
nf4_linear = nn.Linear(512, 512, device="cuda")
torchao.quantize_(nf4_linear, nf4_weight_only())
assert isinstance(nf4_linear.weight, NF4Tensor)

ref_linear = copy.deepcopy(nf4_linear)
ref_linear.weight.data = ref_linear.weight.get_original_weight() # dequantize

if compile:
nf4_linear.compile()
ref_linear.compile()

nf4_x = torch.randn(2, 512, device="cuda").requires_grad_()
ref_x = nf4_x.detach().clone().requires_grad_()

nf4_out = nf4_linear(nf4_x)
ref_out = ref_linear(ref_x)
self.assertEqual(nf4_out, ref_out)

grad_out = torch.randn(2, 512, device="cuda")
nf4_out.backward(grad_out)
ref_out.backward(grad_out)
self.assertEqual(nf4_x.grad, ref_x.grad)


class TestFSDPOps(TestCase):
@parametrize("input_size", [512 * 512, (512 * 512,), (512, 512)])
Expand Down
20 changes: 20 additions & 0 deletions torchao/dtypes/nf4tensor.py
Original file line number Diff line number Diff line change
Expand Up @@ -954,6 +954,15 @@ def to_nf4(tensor, block_size: int = 64, scaler_block_size: int = 256):
return NF4Tensor.from_tensor(tensor, block_size, scaler_block_size)


def nf4_weight_only(block_size: int = 64, scaler_block_size: int = 256):
from torchao.quantization.quant_api import _get_linear_subclass_inserter

def _to_nf4(tensor: torch.Tensor):
return NF4Tensor.from_tensor(tensor, block_size, scaler_block_size)

return _get_linear_subclass_inserter(_to_nf4)


NF4_TORCH_FUNCTIONS = {}


Expand Down Expand Up @@ -1000,6 +1009,17 @@ def function_cpu(*args, **kwargs):
return NF4Tensor(*construct_nf4_args(nf4tensor, updated_attrs))


@implements_torch_function(F.linear)
def _(*args, **kwargs):
input = args[0]
weight = args[1]
bias = args[2] if len(args) > 2 else None
out = LinearNF4.apply(input, weight)
if bias is not None:
out = out + bias
return out


@torch._dynamo.allow_in_graph
def nf4_constructor(
tensor_meta: SubclassTensorArgs,
Expand Down
Loading