diff --git a/test/dtypes/test_fbgemm_fp8.py b/test/dtypes/test_fbgemm_fp8.py
index d2f1e2d82a..56cf5ea081 100644
--- a/test/dtypes/test_fbgemm_fp8.py
+++ b/test/dtypes/test_fbgemm_fp8.py
@@ -25,24 +25,87 @@
 @unittest.skipIf(not TORCH_VERSION_AT_LEAST_2_8, "Need pytorch 2.8+")
+@unittest.skipIf(not torch.cuda.is_available(), "Need CUDA available")
+@unittest.skipIf(not is_sm_at_least_90(), "Need sm90+")
 class TestFbgemmFp8Tensor(TestCase):
-    @unittest.skipIf(not torch.cuda.is_available(), "Need CUDA available")
-    @unittest.skipIf(not is_sm_at_least_90(), "Nedd sm90+")
+    def setUp(self):
+        self.config = FbgemmConfig(
+            input_dtype=e4m3_dtype,
+            weight_dtype=e4m3_dtype,
+            output_dtype=torch.bfloat16,
+        )
+
     def test_linear(self):
         dtype = torch.bfloat16
         device = "cuda"
         input = torch.randn(1, 128, dtype=dtype, device=device)
         linear = torch.nn.Linear(128, 256, dtype=dtype, device=device)
         original = linear(input)
-        config = FbgemmConfig(
-            input_dtype=e4m3_dtype,
-            weight_dtype=e4m3_dtype,
-            output_dtype=torch.bfloat16,
-        )
-        quantize_(linear, config)
+        quantize_(linear, self.config)
         quantized = linear(input)
         self.assertTrue(compute_error(original, quantized) > 20)
 
+    def test_slice(self):
+        dtype = torch.bfloat16
+        device = "cuda"
+        dummy = torch.nn.Linear(256, 256, bias=False, dtype=dtype, device=device)
+        dummy1 = torch.nn.Linear(256, 64, bias=False, dtype=dtype, device=device)
+        dummy1.weight = torch.nn.Parameter(
+            dummy.weight.narrow(0, 0, 64), requires_grad=False
+        )
+        dummy2 = torch.nn.Linear(128, 256, dtype=dtype, device=device)
+        dummy2.weight = torch.nn.Parameter(
+            dummy.weight.narrow(1, 0, 128), requires_grad=False
+        )
+
+        quantize_(dummy, self.config)
+        weight1 = dummy.weight.narrow(0, 0, 64)
+        weight2 = dummy.weight.narrow(1, 0, 128)
+        self.assertEqual(weight1.float8_data, dummy.weight.float8_data.narrow(0, 0, 64))
+        self.assertEqual(weight1.scale, dummy.weight.scale.narrow(0, 0, 64))
+        self.assertEqual(
+            weight2.float8_data, dummy.weight.float8_data.narrow(1, 0, 128)
+        )
+        self.assertEqual(weight2.scale, dummy.weight.scale)
+
+        # check that the sliced weight does not differ too much
+        # before and after float8 quantization
+        input = torch.randn(2, 256, dtype=dtype, device=device)
+        res_ref = dummy1(input)
+        dummy.weight = torch.nn.Parameter(weight1, requires_grad=False)
+        res = dummy(input)
+        assert compute_error(res, res_ref) > 25
+
+        input = torch.randn(2, 128, dtype=dtype, device=device)
+        res_ref = dummy2(input)
+        dummy.weight = torch.nn.Parameter(weight2, requires_grad=False)
+        res = dummy(input)
+        assert compute_error(res, res_ref) > 15
+
+    def test_slice_and_copy_(self):
+        l = torch.nn.Linear(1024, 1024).to("cuda").to(torch.bfloat16)
+        l.weight = torch.nn.Parameter(
+            torch.zeros(1024, 1024, dtype=torch.bfloat16, device="cuda")
+        )
+        quantize_(l, self.config)
+        param = l.weight
+        param_data = param.data
+        param_data = param_data.narrow(0, 0, 512)
+        assert param.data.float8_data.data_ptr() == param_data.float8_data.data_ptr()
+        assert param.data.scale.data_ptr() == param_data.scale.data_ptr()
+        orig_value = param.data.float8_data[0][0].item()
+
+        # dummy_l has a random weight (shouldn't be 0)
+        dummy_l = torch.nn.Linear(1024, 1024).to("cuda").to(torch.bfloat16)
+        quantize_(dummy_l, self.config)
+        quantized = dummy_l.weight
+        quantized = quantized.narrow(0, 0, 512)
+
+        param_data.copy_(quantized)
+
+        # making sure param.data is updated
+        assert param.data.float8_data[0][0] != orig_value
+
 
 if __name__ == "__main__":
     run_tests()
diff --git a/test/dtypes/test_fbgemm_int4.py b/test/dtypes/test_fbgemm_int4.py
index 22fe5bc110..25b71f0244 100644
--- a/test/dtypes/test_fbgemm_int4.py
+++ b/test/dtypes/test_fbgemm_int4.py
@@ -24,25 +24,93 @@
 @unittest.skipIf(not TORCH_VERSION_AT_LEAST_2_8, "Need pytorch 2.8+")
+@unittest.skipIf(not torch.cuda.is_available(), "Need CUDA available")
+@unittest.skipIf(not is_sm_at_least_90(), "Need sm90+")
 class TestFbgemmInt4Tensor(TestCase):
-    @unittest.skipIf(not torch.cuda.is_available(), "Need CUDA available")
-    @unittest.skipIf(not is_sm_at_least_90(), "Nedd sm90+")
+    def setUp(self):
+        self.config = FbgemmConfig(
+            input_dtype=torch.bfloat16,
+            weight_dtype=torch.int4,
+            output_dtype=torch.bfloat16,
+            block_size=[1, 128],
+        )
+
     def test_linear(self):
         dtype = torch.bfloat16
         device = "cuda"
         input = torch.randn(1, 128, dtype=dtype, device=device)
         linear = torch.nn.Linear(128, 256, dtype=dtype, device=device)
         original = linear(input)
-        config = FbgemmConfig(
-            input_dtype=torch.bfloat16,
-            weight_dtype=torch.int4,
-            output_dtype=torch.bfloat16,
-            block_size=[1, 128],
-        )
-        quantize_(linear, config)
+        quantize_(linear, self.config)
         quantized = linear(input)
         self.assertTrue(compute_error(original, quantized) > 20)
 
+    def test_slice(self):
+        dtype = torch.bfloat16
+        device = "cuda"
+        dummy = torch.nn.Linear(256, 256, bias=False, dtype=dtype, device=device)
+        dummy1 = torch.nn.Linear(256, 64, bias=False, dtype=dtype, device=device)
+        dummy1.weight = torch.nn.Parameter(
+            dummy.weight.narrow(0, 0, 64), requires_grad=False
+        )
+        dummy2 = torch.nn.Linear(128, 256, dtype=dtype, device=device)
+        dummy2.weight = torch.nn.Parameter(
+            dummy.weight.narrow(1, 0, 128), requires_grad=False
+        )
+
+        quantize_(dummy, self.config)
+        weight1 = dummy.weight.narrow(0, 0, 64)
+        weight2 = dummy.weight.narrow(1, 0, 128)
+        self.assertEqual(
+            weight1.packed_weight, dummy.weight.packed_weight.narrow(0, 0, 64)
+        )
+        self.assertEqual(weight1.scale, dummy.weight.scale.narrow(1, 0, 64))
+        self.assertEqual(
+            weight2.packed_weight, dummy.weight.packed_weight.narrow(1, 0, 64)
+        )
+        self.assertEqual(weight2.scale, dummy.weight.scale.narrow(0, 0, 1))
+
+        # check that the sliced weight does not differ too much
+        # before and after int4 quantization
+        input = torch.randn(2, 256, dtype=dtype, device=device)
+        res_ref = dummy1(input)
+        dummy.weight = torch.nn.Parameter(weight1, requires_grad=False)
+        res = dummy(input)
+        assert compute_error(res, res_ref) > 20
+
+        input = torch.randn(2, 128, dtype=dtype, device=device)
+        res_ref = dummy2(input)
+        dummy.weight = torch.nn.Parameter(weight2, requires_grad=False)
+        res = dummy(input)
+        assert compute_error(res, res_ref) > 15
+
+    def test_slice_and_copy_(self):
+        l = torch.nn.Linear(1024, 1024).to("cuda").to(torch.bfloat16)
+        l.weight = torch.nn.Parameter(
+            torch.zeros(1024, 1024, dtype=torch.bfloat16, device="cuda")
+        )
+        quantize_(l, self.config)
+        param = l.weight
+        param_data = param.data
+        param_data = param_data.narrow(0, 0, 512)
+        assert (
+            param.data.packed_weight.data_ptr() == param_data.packed_weight.data_ptr()
+        )
+        assert param.data.scale.data_ptr() == param_data.scale.data_ptr()
+        assert param.data.zero_point.data_ptr() == param_data.zero_point.data_ptr()
+        orig_value = param.data.packed_weight[0][0].item()
+
+        # dummy_l has a random weight (shouldn't be 0)
+        dummy_l = torch.nn.Linear(1024, 1024).to("cuda").to(torch.bfloat16)
+        quantize_(dummy_l, self.config)
+        quantized = dummy_l.weight
+        quantized = quantized.narrow(0, 0, 512)
+
+        param_data.copy_(quantized)
+
+        # making sure param.data is updated
+        assert param.data.packed_weight[0][0] != orig_value
+
 
 if __name__ == "__main__":
     run_tests()
diff --git a/torchao/dtypes/fbgemm_fp8_tensor.py b/torchao/dtypes/fbgemm_fp8_tensor.py
index 735c21c2ca..df7ce69de7 100644
--- a/torchao/dtypes/fbgemm_fp8_tensor.py
+++ b/torchao/dtypes/fbgemm_fp8_tensor.py
@@ -13,6 +13,7 @@
 from torchao.utils import (
     TORCH_VERSION_AT_LEAST_2_5,
     TorchAOBaseTensor,
+    fill_defaults,
 )
 
 __all__ = [
@@ -23,6 +24,10 @@
 class FbgemmFp8Tensor(TorchAOBaseTensor):
+    """
+    TODO: needs padding for cutlass kernels
+    """
+
     tensor_data_attrs = ["float8_data", "scale", "activation_scale_ub"]
     tensor_attributes = ["dtype"]
 
@@ -118,9 +123,13 @@ def _(func, types, args, kwargs):
     xq, x_scale = torch.ops.fbgemm.quantize_fp8_per_row(
         input_tensor, num_tokens, weight_tensor.activation_scale_ub
     )
+
+    a_data = xq
+    b_data = weight_tensor.float8_data
+
     res = torch.ops.fbgemm.f8f8bf16_rowwise(
-        xq,
-        weight_tensor.float8_data,
+        a_data,
+        b_data,
         x_scale,
         weight_tensor.scale,
         use_fast_accum=True,
@@ -139,13 +148,87 @@ def _(func, types, args, kwargs):
     )
 
 
-@implements([aten.clone.default, aten.copy_.default])
+@implements(aten.clone.default)
 def _(func, types, args, kwargs):
     return return_and_correct_aliasing(
         func, args, kwargs, args[0]._apply_fn_to_data(torch.clone)
     )
 
 
+def _same_metadata(self: "FbgemmFp8Tensor", src: "FbgemmFp8Tensor") -> bool:
+    return (
+        isinstance(self, FbgemmFp8Tensor)
+        and isinstance(src, FbgemmFp8Tensor)
+        and self.shape == src.shape
+        and self.float8_data.shape == src.float8_data.shape
+        and self.scale.shape == src.scale.shape
+        and self.activation_scale_ub.shape == src.activation_scale_ub.shape
+        and self.dtype == src.dtype
+    )
+
+
+@implements(aten.copy_.default)
+def _(func, types, args, kwargs):
+    self = args[0]
+    src = args[1]
+    if _same_metadata(self, src):
+        self_tensors = self.__tensor_flatten__()[0]
+        for tensor_name in self_tensors:
+            getattr(self, tensor_name).copy_(getattr(src, tensor_name))
+        return
+    raise ValueError(
+        f"Not supported args for copy_ due to metadata mismatch: {args[0], args[1]}"
+    )
+
+
+@implements(aten.slice.Tensor)
+def _(func, types, args, kwargs):
+    """Only supports slicing for dim == 0 and dim == 1
+    original tensor shape has dimension (N, K)
+    float8_data has dimension (N, K)
+    scale (per row quantization) has dimension: (N,)
+
+    since float8_data has the same dimensions as the original tensor, we can directly slice it;
+    for scale, we do a slice when dim is 0, and don't need to do anything for dim 1
+
+    Note that we need to call slice on the float8_data and scale directly because slice
+    is an operation that needs to preserve aliasing, see `test_slice_and_copy_` in `test_fbgemm_fp8`
+    for an example
+    """
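+    # Illustrative mapping, using the shapes from `test_slice` in test_fbgemm_fp8.py
+    # (original weight (N, K) == (256, 256), so float8_data is (256, 256) and scale is (256,)):
+    #   slicing dim 0 over [0, 64)  -> float8_data[0:64, :] and scale[0:64]
+    #   slicing dim 1 over [0, 128) -> float8_data[:, 0:128], scale unchanged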
+    self, dim, start, end, step = fill_defaults(args, 5, [0, None, None, 1])
+    assert step == 1
+    assert dim == 0 or dim == 1, f"Only dim==0 or 1 are supported, got: {dim}"
+    if end >= self.shape[dim]:
+        end = self.shape[dim]
+
+    assert self.float8_data.ndim == 2, (
+        f"Expected float8_data to have 2 dimensions, got {self.float8_data.ndim}"
+    )
+
+    # Always slice the float8_data
+    sliced_data = aten.slice.Tensor(
+        self.float8_data, dim, start, end, step
+    ).contiguous()
+
+    if dim == 0:
+        # scale has dimension (N,) where N is dim 0 of `self`,
+        # so we do the same slice on scale for dimension 0
+        sliced_scale = aten.slice.Tensor(self.scale, 0, start, end, step)
+    else:
+        # since scale is per row, slicing along dim == 1 does
+        # not change the scale
+        sliced_scale = self.scale
+
+    return return_and_correct_aliasing(
+        func,
+        args,
+        kwargs,
+        FbgemmFp8Tensor(
+            sliced_data, sliced_scale, self.activation_scale_ub, dtype=self.dtype
+        ),
+    )
+
+
 to_fbgemm_fp8 = FbgemmFp8Tensor.from_float
diff --git a/torchao/dtypes/fbgemm_int4_tensor.py b/torchao/dtypes/fbgemm_int4_tensor.py
index c2ab6246bf..ab108fea06 100644
--- a/torchao/dtypes/fbgemm_int4_tensor.py
+++ b/torchao/dtypes/fbgemm_int4_tensor.py
@@ -14,6 +14,7 @@
 from torchao.utils import (
     TORCH_VERSION_AT_LEAST_2_5,
     TorchAOBaseTensor,
+    fill_defaults,
 )
 
 __all__ = [
@@ -32,17 +33,16 @@
 class FbgemmInt4Tensor(TorchAOBaseTensor):
     tensor_data_attrs = ["packed_weight", "scale", "zero_point"]
-    tensor_attributes = ["group_size"]
+    tensor_attributes = ["group_size", "shape"]
 
-    def __new__(cls, packed_weight, scale, zero_point, group_size):
-        shape = packed_weight.shape
+    def __new__(cls, packed_weight, scale, zero_point, group_size, shape):
         kwargs = {}
         kwargs["device"] = packed_weight.device
         kwargs["dtype"] = scale.dtype
         kwargs["requires_grad"] = False
         return torch.Tensor._make_wrapper_subclass(cls, shape, **kwargs)  # type: ignore[attr-defined]
 
-    def __init__(self, packed_weight, scale, zero_point, group_size):
+    def __init__(self, packed_weight, scale, zero_point, group_size, shape):
         self.packed_weight = packed_weight
         self.scale = scale
         self.zero_point = zero_point
@@ -90,6 +90,7 @@ def from_float(
             raise ImportError("Requires fbgemm-gpu-genai >= 1.2.0")
 
         group_size = block_size[-1]
+        original_shape = w.shape
 
         if w.ndim >= 3:
             wq, scale, zero_point = zip(
@@ -111,6 +112,7 @@
             scale=scale,
             zero_point=zero_point,
             group_size=group_size,
+            shape=original_shape,
         )
 
 
@@ -134,7 +136,7 @@ def _(func, types, args, kwargs):
     res = torch.ops.fbgemm.bf16i4bf16_rowwise(
         input_tensor,
-        weight_tensor.packed_weight,
+        weight_tensor.packed_weight.contiguous(),
         weight_tensor.scale,
         weight_tensor.zero_point,
     )
@@ -151,13 +153,115 @@ def _(func, types, args, kwargs):
     )
 
 
-@implements([aten.clone.default, aten.copy_.default])
+@implements(aten.clone.default)
 def _(func, types, args, kwargs):
     return return_and_correct_aliasing(
         func, args, kwargs, args[0]._apply_fn_to_data(torch.clone)
     )
 
 
+def _same_metadata(self: "FbgemmInt4Tensor", src: "FbgemmInt4Tensor") -> bool:
+    return (
+        isinstance(self, FbgemmInt4Tensor)
+        and isinstance(src, FbgemmInt4Tensor)
+        and self.shape == src.shape
+        and self.packed_weight.shape == src.packed_weight.shape
+        and self.scale.shape == src.scale.shape
+        and self.zero_point.shape == src.zero_point.shape
+        and self.group_size == src.group_size
+    )
+
+
+@implements(aten.copy_.default)
+def _(func, types, args, kwargs):
+    self = args[0]
+    src = args[1]
+    if _same_metadata(self, src):
+        self_tensors = self.__tensor_flatten__()[0]
+        for tensor_name in self_tensors:
+            getattr(self, tensor_name).copy_(getattr(src, tensor_name))
+        return
+    raise ValueError(
+        f"Not supported args for copy_ due to metadata mismatch: {args[0], args[1]}"
+    )
+
+
+@implements(aten.slice.Tensor)
+def _(func, types, args, kwargs):
+    """Only supports slicing for dim == 0 and dim == 1
+    packed_weight has dimension: (N, K/2)
+    scale and zero_point have dimension: (K/groups, N)
+
+    dim, start, end, step are args referring to the original tensor shape,
+    which is (N, K), and we need to map that to the transformed weight shapes of packed_weight,
+    scale and zero_point
+
+    when dim == 0: we do a slice on packed_weight dim 0, and on dim 1 of scale and zero_point,
+    also adjusting the start and end indexes based on the ratio between the original shape and the
+    shape of packed_weight and scale/zero_point
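+    (illustrative example: for an original weight of (N, K) == (256, 256) with group_size == 128,
+    as in `test_slice` of test_fbgemm_int4.py, packed_weight is (256, 128) and scale/zero_point
+    are (2, 256); slicing rows [0, 64) of the original weight then maps to packed_weight[0:64, :]
+    and scale[:, 0:64] / zero_point[:, 0:64])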
+
+    when dim == 1: we do a slice on packed_weight dim 1 and on dim 0 of scale and zero_point, and do
+    the same adjustment based on the ratio
+
+    Note that we need to call slice on the packed_weight, scale and zero_point directly because slice
+    is an operation that needs to preserve aliasing, see `test_slice_and_copy_` in `test_fbgemm_int4`
+    for an example
+    """
+    self, dim, start, end, step = fill_defaults(args, 5, [0, None, None, 1])
+    assert step == 1
+    assert dim == 0 or dim == 1, f"Only dim==0 or 1 are supported, got: {dim}"
+    if end >= self.shape[dim]:
+        end = self.shape[dim]
+
+    assert self.packed_weight.ndim == 2, (
+        f"Expected packed_weight to have 2 dimensions, got {self.packed_weight.ndim}"
+    )
+    N, K_by_2 = self.packed_weight.shape
+    sz_dim0, sz_dim1 = self.scale.shape
+
+    data_len = self.shape[dim]
+
+    if dim == 0:
+        pw_len = N
+        sz_len = sz_dim1
+    else:
+        pw_len = K_by_2
+        sz_len = sz_dim0
+
+    sz_dim = 1 - dim
+    if pw_len == 0 or sz_len == 0:
+        return return_and_correct_aliasing(
+            func,
+            args,
+            kwargs,
+            self.__class__(
+                self.packed_weight,
+                self.scale,
+                self.zero_point,
+                group_size=self.group_size,
+                shape=self.shape,
+            ),
+        )
+
+    pw_ratio = data_len / pw_len
+    start_pw = int(start / pw_ratio)
+    end_pw = int(end / pw_ratio)
+
+    sz_ratio = data_len / sz_len
+    start_sz = int(start / sz_ratio)
+    end_sz = int(end / sz_ratio)
+
+    packed_weight = aten.slice.Tensor(self.packed_weight, dim, start_pw, end_pw, step)
+    scale = aten.slice.Tensor(self.scale, sz_dim, start_sz, end_sz, step)
+    zero_point = aten.slice.Tensor(self.zero_point, sz_dim, start_sz, end_sz, step)
+    packed_shape0, packed_shape1 = packed_weight.shape
+    new_shape = (packed_shape0, packed_shape1 * 2)
+    new = self.__class__(
+        packed_weight, scale, zero_point, group_size=self.group_size, shape=new_shape
+    )
+    return return_and_correct_aliasing(func, args, kwargs, new)
+
+
 to_fbgemm_int4 = FbgemmInt4Tensor.from_float