Skip to content

Add slicing support for fbgemm fp8 and int4 #2308

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 1 commit into from
Jun 6, 2025
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
79 changes: 71 additions & 8 deletions test/dtypes/test_fbgemm_fp8.py
Original file line number Diff line number Diff line change
Expand Up @@ -25,24 +25,87 @@


@unittest.skipIf(not TORCH_VERSION_AT_LEAST_2_8, "Need pytorch 2.8+")
@unittest.skipIf(not torch.cuda.is_available(), "Need CUDA available")
@unittest.skipIf(not is_sm_at_least_90(), "Nedd sm90+")
class TestFbgemmFp8Tensor(TestCase):
@unittest.skipIf(not torch.cuda.is_available(), "Need CUDA available")
@unittest.skipIf(not is_sm_at_least_90(), "Nedd sm90+")
def setUp(self):
self.config = FbgemmConfig(
input_dtype=e4m3_dtype,
weight_dtype=e4m3_dtype,
output_dtype=torch.bfloat16,
)

def test_linear(self):
dtype = torch.bfloat16
device = "cuda"
input = torch.randn(1, 128, dtype=dtype, device=device)
linear = torch.nn.Linear(128, 256, dtype=dtype, device=device)
original = linear(input)
config = FbgemmConfig(
input_dtype=e4m3_dtype,
weight_dtype=e4m3_dtype,
output_dtype=torch.bfloat16,
)
quantize_(linear, config)
quantize_(linear, self.config)
quantized = linear(input)
self.assertTrue(compute_error(original, quantized) > 20)

def test_slice(self):
dtype = torch.bfloat16
device = "cuda"
dummy = torch.nn.Linear(256, 256, bias=False, dtype=dtype, device=device)
dummy1 = torch.nn.Linear(256, 64, bias=False, dtype=dtype, device=device)
dummy1.weight = torch.nn.Parameter(
dummy.weight.narrow(0, 0, 64), requires_grad=False
)
dummy2 = torch.nn.Linear(128, 256, dtype=dtype, device=device)
dummy2.weight = torch.nn.Parameter(
dummy.weight.narrow(1, 0, 128), requires_grad=False
)

quantize_(dummy, self.config)
weight1 = dummy.weight.narrow(0, 0, 64)
weight2 = dummy.weight.narrow(1, 0, 128)
self.assertEqual(weight1.float8_data, dummy.weight.float8_data.narrow(0, 0, 64))
self.assertEqual(weight1.scale, dummy.weight.scale.narrow(0, 0, 64))
self.assertEqual(
weight2.float8_data, dummy.weight.float8_data.narrow(1, 0, 128)
)
self.assertEqual(weight2.scale, dummy.weight.scale)

# check for sliced weight, before and after float8 quantization
# does not differ too much
input = torch.randn(2, 256, dtype=dtype, device=device)
res_ref = dummy1(input)
dummy.weight = torch.nn.Parameter(weight1, requires_grad=False)
res = dummy(input)
assert compute_error(res, res_ref) > 25

input = torch.randn(2, 128, dtype=dtype, device=device)
res_ref = dummy2(input)
dummy.weight = torch.nn.Parameter(weight2, requires_grad=False)
res = dummy(input)
assert compute_error(res, res_ref) > 15

def test_slice_and_copy_(self):
l = torch.nn.Linear(1024, 1024).to("cuda").to(torch.bfloat16)
l.weight = torch.nn.Parameter(
torch.zeros(1024, 1024, dtype=torch.bfloat16, device="cuda")
)
quantize_(l, self.config)
param = l.weight
param_data = param.data
param_data = param_data.narrow(0, 0, 512)
assert param.data.float8_data.data_ptr() == param_data.float8_data.data_ptr()
assert param.data.scale.data_ptr() == param_data.scale.data_ptr()
orig_value = param.data.float8_data[0][0].item()

# dummy_l has random input (shouldn't be 0)
dummy_l = torch.nn.Linear(1024, 1024).to("cuda").to(torch.bfloat16)
quantize_(dummy_l, self.config)
quantized = dummy_l.weight
quantized = quantized.narrow(0, 0, 512)

param_data.copy_(quantized)

# making sure param.data is updated
assert param.data.float8_data[0][0] != orig_value


if __name__ == "__main__":
run_tests()
86 changes: 77 additions & 9 deletions test/dtypes/test_fbgemm_int4.py
Original file line number Diff line number Diff line change
Expand Up @@ -24,25 +24,93 @@


@unittest.skipIf(not TORCH_VERSION_AT_LEAST_2_8, "Need pytorch 2.8+")
@unittest.skipIf(not torch.cuda.is_available(), "Need CUDA available")
@unittest.skipIf(not is_sm_at_least_90(), "Nedd sm90+")
class TestFbgemmInt4Tensor(TestCase):
@unittest.skipIf(not torch.cuda.is_available(), "Need CUDA available")
@unittest.skipIf(not is_sm_at_least_90(), "Nedd sm90+")
def setUp(self):
self.config = FbgemmConfig(
input_dtype=torch.bfloat16,
weight_dtype=torch.int4,
output_dtype=torch.bfloat16,
block_size=[1, 128],
)

def test_linear(self):
dtype = torch.bfloat16
device = "cuda"
input = torch.randn(1, 128, dtype=dtype, device=device)
linear = torch.nn.Linear(128, 256, dtype=dtype, device=device)
original = linear(input)
config = FbgemmConfig(
input_dtype=torch.bfloat16,
weight_dtype=torch.int4,
output_dtype=torch.bfloat16,
block_size=[1, 128],
)
quantize_(linear, config)
quantize_(linear, self.config)
quantized = linear(input)
self.assertTrue(compute_error(original, quantized) > 20)

def test_slice(self):
dtype = torch.bfloat16
device = "cuda"
dummy = torch.nn.Linear(256, 256, bias=False, dtype=dtype, device=device)
dummy1 = torch.nn.Linear(256, 64, bias=False, dtype=dtype, device=device)
dummy1.weight = torch.nn.Parameter(
dummy.weight.narrow(0, 0, 64), requires_grad=False
)
dummy2 = torch.nn.Linear(128, 256, dtype=dtype, device=device)
dummy2.weight = torch.nn.Parameter(
dummy.weight.narrow(1, 0, 128), requires_grad=False
)

quantize_(dummy, self.config)
weight1 = dummy.weight.narrow(0, 0, 64)
weight2 = dummy.weight.narrow(1, 0, 128)
self.assertEqual(
weight1.packed_weight, dummy.weight.packed_weight.narrow(0, 0, 64)
)
self.assertEqual(weight1.scale, dummy.weight.scale.narrow(1, 0, 64))
self.assertEqual(
weight2.packed_weight, dummy.weight.packed_weight.narrow(1, 0, 64)
)
self.assertEqual(weight2.scale, dummy.weight.scale.narrow(0, 0, 1))

# check for sliced weight, before and after float8 quantization
# does not differ too much
input = torch.randn(2, 256, dtype=dtype, device=device)
res_ref = dummy1(input)
dummy.weight = torch.nn.Parameter(weight1, requires_grad=False)
res = dummy(input)
assert compute_error(res, res_ref) > 20

input = torch.randn(2, 128, dtype=dtype, device=device)
res_ref = dummy2(input)
dummy.weight = torch.nn.Parameter(weight2, requires_grad=False)
res = dummy(input)
assert compute_error(res, res_ref) > 15

def test_slice_and_copy_(self):
l = torch.nn.Linear(1024, 1024).to("cuda").to(torch.bfloat16)
l.weight = torch.nn.Parameter(
torch.zeros(1024, 1024, dtype=torch.bfloat16, device="cuda")
)
quantize_(l, self.config)
param = l.weight
param_data = param.data
param_data = param_data.narrow(0, 0, 512)
assert (
param.data.packed_weight.data_ptr() == param_data.packed_weight.data_ptr()
)
assert param.data.scale.data_ptr() == param_data.scale.data_ptr()
assert param.data.zero_point.data_ptr() == param_data.zero_point.data_ptr()
orig_value = param.data.packed_weight[0][0].item()

# dummy_l has random input (shouldn't be 0)
dummy_l = torch.nn.Linear(1024, 1024).to("cuda").to(torch.bfloat16)
quantize_(dummy_l, self.config)
quantized = dummy_l.weight
quantized = quantized.narrow(0, 0, 512)

param_data.copy_(quantized)

# making sure param.data is updated
assert param.data.packed_weight[0][0] != orig_value


if __name__ == "__main__":
run_tests()
89 changes: 86 additions & 3 deletions torchao/dtypes/fbgemm_fp8_tensor.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,7 @@
from torchao.utils import (
TORCH_VERSION_AT_LEAST_2_5,
TorchAOBaseTensor,
fill_defaults,
)

__all__ = [
Expand All @@ -23,6 +24,10 @@


class FbgemmFp8Tensor(TorchAOBaseTensor):
"""
TODO: needs padding for cutlass kernels
"""

tensor_data_attrs = ["float8_data", "scale", "activation_scale_ub"]
tensor_attributes = ["dtype"]

Expand Down Expand Up @@ -118,9 +123,13 @@ def _(func, types, args, kwargs):
xq, x_scale = torch.ops.fbgemm.quantize_fp8_per_row(
input_tensor, num_tokens, weight_tensor.activation_scale_ub
)

a_data = xq
b_data = weight_tensor.float8_data

res = torch.ops.fbgemm.f8f8bf16_rowwise(
xq,
weight_tensor.float8_data,
a_data,
b_data,
x_scale,
weight_tensor.scale,
use_fast_accum=True,
Expand All @@ -139,13 +148,87 @@ def _(func, types, args, kwargs):
)


@implements([aten.clone.default, aten.copy_.default])
@implements(aten.clone.default)
def _(func, types, args, kwargs):
return return_and_correct_aliasing(
func, args, kwargs, args[0]._apply_fn_to_data(torch.clone)
)


def _same_metadata(self: "FbgemmFp8Tensor", src: "FbgemmFp8Tensor") -> bool:
return (
isinstance(self, FbgemmFp8Tensor)
and isinstance(src, FbgemmFp8Tensor)
and self.shape == src.shape
and self.float8_data.shape == src.float8_data.shape
and self.scale.shape == src.scale.shape
and self.activation_scale_ub.shape == src.activation_scale_ub.shape
and self.dtype == src.dtype
)


@implements(aten.copy_.default)
def _(func, types, args, kwargs):
self = args[0]
src = args[1]
if _same_metadata(self, src):
self_tensors = self.__tensor_flatten__()[0]
for tensor_name in self_tensors:
getattr(self, tensor_name).copy_(getattr(src, tensor_name))
return
raise ValueError(
f"Not supported args for copy_ due to metadata mismatch: {args[0], args[1]}"
)


@implements(aten.slice.Tensor)
def _(func, types, args, kwargs):
"""Only supports slicing for dim == 1 and dim == 2
original tensor shape has dimension (N, K)
float8_data has dimension (N, K)
scale (per row quantization) has dimension: (N,)

since float8_data has the same dimension as original tensor, we can directly slice that
for scale, we'll do a slice when dim is 0, and don't need to do anything for dim 1

Note that we need to call slice on the float8_data and scale directly because slice
is an operation that need to preserve aliasing, see `test_slice_and_copy_` in `test_fbgemm_fp8`
for
"""
self, dim, start, end, step = fill_defaults(args, 5, [0, None, None, 1])
assert step == 1
assert dim == 0 or dim == 1, f"Only dim==0 or 1 are supported, got: {dim}"
if end >= self.shape[dim]:
end = self.shape[dim]

assert self.float8_data.ndim == 2, (
f"Expected packed weight to have dim 2, got {self.float8_data.dim}"
)

# Always slice the float8_data
sliced_data = aten.slice.Tensor(
self.float8_data, dim, start, end, step
).contiguous()

if dim == 0:
# scale has dimension (N,) where N is the dim 0 of `self`
# so we do the same slice on scale for dimension 0
sliced_scale = aten.slice.Tensor(self.scale, 0, start, end, step)
else:
# since scale is per row, slicing along the dim == 1 dimension does
# not change the scale
sliced_scale = self.scale

return return_and_correct_aliasing(
func,
args,
kwargs,
FbgemmFp8Tensor(
sliced_data, sliced_scale, self.activation_scale_ub, dtype=self.dtype
),
)


to_fbgemm_fp8 = FbgemmFp8Tensor.from_float


Expand Down
Loading
Loading