
Feat/blockwise fp8 quant #1668


102 changes: 102 additions & 0 deletions benchmarks/benchmark_blockwise_scaled_linear_triton.py
@@ -0,0 +1,102 @@
import pandas as pd
import torch
from tqdm import tqdm
from triton.testing import do_bench

from torchao.float8.float8_utils import compute_error
from torchao.ops import rowwise_scaled_linear_cutlass_s8s4
from torchao.prototype.blockwise_fp8.blockwise_fp8_gemm_triton import blockwise_fp8_gemm
from torchao.prototype.blockwise_fp8.blockwise_quantization import (
fp8_blockwise_act_quant,
fp8_blockwise_weight_quant,
)
from torchao.quantization.quant_api import (
int8_dynamic_activation_int4_weight,
quantize_,
)


def benchmark_microseconds(f, *args):
return do_bench(lambda: f(*args), return_mode="median") * 1e3

def get_rowwise_problem(m: int, n: int, k: int):
dev = torch.device("cuda")
A = torch.randint(-128, 127, (m, k), dtype=torch.int8, device=dev)
A_scale = torch.randn((m,), dtype=torch.half, device=dev)
B = torch.randint(
-128, 127, size=(n, 4 * k // 8), dtype=torch.int8, device=dev
)
B_scale = torch.randn((n,), dtype=torch.half, device=dev)
C = None

return A, A_scale, B, B_scale, C

def get_blockwise_problem(m: int, n: int, k: int, block_size: int):
assert n % block_size == 0 and k % block_size == 0, "N and K dims must be divisible by block_size"
dev = torch.device("cuda")
A = (448.0 * (2 * torch.rand(m, k, device=dev) - 1)).to(torch.float8_e4m3fn)
A_scale = torch.randn((m, k // block_size), dtype=torch.half, device=dev)
B = (448.0 * (2 * torch.rand(n, k, device=dev) - 1)).to(torch.float8_e4m3fn)
B_scale = torch.randn((n // block_size, k // block_size), dtype=torch.half, device=dev)

return A, A_scale, B, B_scale

def benchmark(m: int, k: int, n: int, block_size: int):
# Speed benchmark
dev = torch.device("cuda")
A_ref = torch.randn((m, k), dtype=torch.half, device=dev)
B_ref = torch.randn((n, k), dtype=torch.half, device=dev)
fp16_time = benchmark_microseconds(torch.nn.functional.linear, A_ref, B_ref)

A, A_scale, B, B_scale, C = get_rowwise_problem(m, n, k)
rowwise_scaled_linear_cutlass_s8s4_time = benchmark_microseconds(
rowwise_scaled_linear_cutlass_s8s4, A, A_scale, B, B_scale, C
)

A, A_scale, B, B_scale = get_blockwise_problem(m, n, k, block_size)
blockwise_fp8_gemm_time = benchmark_microseconds(
blockwise_fp8_gemm, A, A_scale, B, B_scale
)

# Precision benchmark
    lin = torch.nn.Linear(k, n, bias=False, device=dev, dtype=torch.half)
A = torch.randn((m, k), dtype=torch.half, device=dev)
W = lin.weight
output = A @ W.T

A_q, A_s = fp8_blockwise_act_quant(A, block_size)
W_q, W_s = fp8_blockwise_weight_quant(W, block_size)
output_blockwise_quant = blockwise_fp8_gemm(A_q, A_s, W_q, W_s)

quantize_(lin, int8_dynamic_activation_int4_weight())
output_rowwise_quant = lin(A)

error_rowwise_quant = compute_error(output, output_rowwise_quant)
error_blockwise_quant = compute_error(output, output_blockwise_quant)

return {
"m": m,
"k": k,
"n": n,
"fp16_latency (ms)": fp16_time,
"rowwise_scaled_linear_cutlass_s8s4 latency (ms)": rowwise_scaled_linear_cutlass_s8s4_time,
"rowwise s8s4 speedup (d/s)": fp16_time / rowwise_scaled_linear_cutlass_s8s4_time,
"blockwise_fp8_gemm latency (ms)": blockwise_fp8_gemm_time,
"blockwise fp8 speedup (d/s)": fp16_time / blockwise_fp8_gemm_time,
"error_rowwise_quant (dB)": error_rowwise_quant,
"error_blockwise_quant (dB)": error_blockwise_quant
}

if __name__ == "__main__":
k_vals = (8192, 8192, 8192, 28672)
n_vals = (8192, 10240, 57344, 8192)
block_size_vals = (128, 128, 128, 128)

results = []
for m in tqdm([1 << i for i in range(10)]):
for n, k, block_size in zip(n_vals, k_vals, block_size_vals):
results.append(benchmark(m, k, n, block_size))

df = pd.DataFrame(results)
df.to_csv("blockwise_scaled_linear_triton_results.csv", index=False)
print(df.to_markdown(index=False))
7 changes: 5 additions & 2 deletions benchmarks/float8/bench_matmul.py
@@ -124,10 +124,13 @@ def run(
if scaling_granularity == ScalingGranularity.TENSORWISE:
scale_a = torch.tensor([1.0], device=device)
scale_b = torch.tensor([1.0], device=device)
else:
assert scaling_granularity == ScalingGranularity.AXISWISE, "unsupported"
elif scaling_granularity == ScalingGranularity.AXISWISE:
scale_a = torch.ones(M, 1, device=device)
scale_b = torch.ones(1, N, device=device)
else:
assert scaling_granularity == ScalingGranularity.BLOCKWISE, "unsupported"
Contributor

This file benchmarks torch._scaled_mm, which does not support blockwise scaling; is this change intended?

Contributor Author

Unintended, but I will rework this PR. There were some details that I had missed when I initially worked on it.
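
For context, here is a minimal sketch (not part of this PR) of the scale shapes that torch._scaled_mm accepts for tensorwise and axiswise scaling, constructed the same way this benchmark constructs them. The op is a private API whose exact signature varies across PyTorch releases, and blockwise scale layouts such as (M, K // block_size) are outside its contract here, which is the point raised above.

import torch

M, K, N = 128, 256, 64
device = "cuda"  # assumes an fp8-capable GPU, e.g. H100
A = torch.randn(M, K, device=device).to(torch.float8_e4m3fn)
B = torch.randn(N, K, device=device).to(torch.float8_e4m3fn).t()  # mat2 must be column-major

# Tensorwise scaling: a single scale per operand.
out_tensorwise = torch._scaled_mm(
    A,
    B,
    scale_a=torch.tensor([1.0], device=device),
    scale_b=torch.tensor([1.0], device=device),
    out_dtype=torch.bfloat16,
)

# Axiswise (rowwise) scaling: one scale per row of A and one per column of B.
out_axiswise = torch._scaled_mm(
    A,
    B,
    scale_a=torch.ones(M, 1, device=device),
    scale_b=torch.ones(1, N, device=device),
    out_dtype=torch.bfloat16,
)

# Blockwise scales (e.g. an activation scale of shape (M, K // 128)) are not
# understood by this op here, so the BLOCKWISE branch cannot be benchmarked
# through torch._scaled_mm.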

scale_a = torch.ones(M, N, device=device)
scale_b = torch.ones(M, N, device=device)

def do_matmul(A, B):
nonlocal scale_a
25 changes: 21 additions & 4 deletions benchmarks/float8/float8_roofline.py
@@ -354,14 +354,30 @@ def run(
m_fp8_dyn_axs = torch.compile(m_fp8_dyn_axs)
fp8_dyn_axs_time_actual_s = get_gpu_kernel_time(m_fp8_dyn_axs, x)

# get the lw recipe scaling gpu kernel time
# get the float8 dynamic blockwise scaling gpu kernel time
torch._dynamo.reset()
config = recipe_name_to_linear_config(Float8LinearRecipeName.ALL_BLOCKWISE)
m_fp8_dyn_blk = convert_to_float8_training(copy.deepcopy(m_orig), config=config)
m_fp8_dyn_blk = torch.compile(m_fp8_dyn_blk)
fp8_dyn_blk_time_actual_s = get_gpu_kernel_time(m_fp8_dyn_blk, x)

# get the lw_axs recipe scaling gpu kernel time
# TODO(future PR): enable below once basic performance issues
# are fixed
# torch._dynamo.reset()
# config = recipe_name_to_linear_config(Float8LinearRecipeName.LW_AXISWISE_WITH_GW_HP)
# m_fp8_lw = convert_to_float8_training(m_orig, config=config)
# m_fp8_lw = torch.compile(m_fp8_lw)
# fp8_lw_time_actual_s = get_gpu_kernel_time(m_fp8_lw, x)
# m_fp8_lw_axs = convert_to_float8_training(m_orig, config=config)
# m_fp8_lw_axs = torch.compile(m_fp8_lw_axs)
# fp8_lw_axs_time_actual_s = get_gpu_kernel_time(m_fp8_lw_axs, x)

# get the lw_blk recipe scaling gpu kernel time
# TODO(future PR): enable below once basic performance issues
# are fixed
# torch._dynamo.reset()
# config = recipe_name_to_linear_config(Float8LinearRecipeName.LW_BLOCKWISE_WITH_GW_HP)
# m_fp8_lw_blk = convert_to_float8_training(m_orig, config=config)
# m_fp8_lw_blk = torch.compile(m_fp8_lw_blk)
# fp8_lw_blk_time_actual_s = get_gpu_kernel_time(m_fp8_lw_blk, x)

results.append(
[
@@ -382,6 +398,7 @@ def run(
fp8_dyn_time_actual_s,
fp8_del_time_actual_s,
fp8_dyn_axs_time_actual_s,
fp8_dyn_blk_time_actual_s,
# fp8_lw_time_actual_s,
bf16_time_actual_s / fp8_dyn_time_actual_s,
bf16_time_actual_s / fp8_del_time_actual_s,
61 changes: 61 additions & 0 deletions test/float8/test_base.py
@@ -43,6 +43,7 @@
from torchao.float8.float8_python_api import addmm_float8_unwrapped
from torchao.float8.float8_scaling_utils import (
get_maybe_axiswise_dim,
get_maybe_blockwise_size,
hp_tensor_to_float8_dynamic,
)
from torchao.float8.float8_tensor import (
@@ -178,6 +179,22 @@ def test_axiswise_dynamic_cast(self, shape, axiswise_dim):
sqnr = compute_error(a, a_dq)
assert sqnr >= 25.0

@pytest.mark.parametrize("shape", [(8, 16), (4, 8, 16), (2, 4, 8, 16)])
@pytest.mark.parametrize("blockwise_size", [4])
def test_blockwise_dynamic_cast(self, shape, blockwise_size):
a = torch.randn(*shape, dtype=torch.bfloat16)
linear_mm_config = LinearMMConfig()
a_fp8 = hp_tensor_to_float8_dynamic(
a,
e4m3_dtype,
linear_mm_config,
scaling_granularity=ScalingGranularity.BLOCKWISE,
blockwise_size=blockwise_size,
)
a_dq = a_fp8.to_original_precision()
sqnr = compute_error(a, a_dq)
assert sqnr >= 25.0

def test_axiswise_reshape(self):
a = torch.randn(3, 5, 7, dtype=torch.bfloat16)
linear_mm_config = LinearMMConfig()
@@ -272,6 +289,48 @@ def test_axiswise_gemm(self, a_shape, a_granularity, b_granularity):
sqnr = compute_error(c_ref, c_fp8_compute)
assert sqnr >= 25.0

@pytest.mark.parametrize("a_shape", [(16, 32), (2, 16, 32), (1, 2, 16, 32)])
@pytest.mark.parametrize(
"a_granularity,b_granularity",
[
(ScalingGranularity.BLOCKWISE, ScalingGranularity.BLOCKWISE),
(ScalingGranularity.BLOCKWISE, ScalingGranularity.TENSORWISE),
(ScalingGranularity.TENSORWISE, ScalingGranularity.BLOCKWISE),
],
)
@unittest.skipIf(not torch.cuda.is_available(), "CUDA not available")
@unittest.skipIf(not is_sm_at_least_90(), "Requires CUDA capability >= 9.0")
def test_blockwise_gemm(self, a_shape, a_granularity, b_granularity):
a = torch.randn(*a_shape, dtype=torch.bfloat16, device="cuda")
b = torch.randn(64, 32, dtype=torch.bfloat16, device="cuda")

linear_mm_config = LinearMMConfig()

a_fp8 = hp_tensor_to_float8_dynamic(
a,
e4m3_dtype,
linear_mm_config,
gemm_input_role=GemmInputRole.INPUT,
scaling_granularity=a_granularity,
blockwise_size=get_maybe_blockwise_size(8, a_granularity),
)
a_fp8 = a_fp8.reshape(-1, a_shape[-1])

b_fp8 = hp_tensor_to_float8_dynamic(
b,
e4m3_dtype,
linear_mm_config,
gemm_input_role=GemmInputRole.WEIGHT,
scaling_granularity=b_granularity,
blockwise_size=get_maybe_blockwise_size(8, b_granularity),
)

c_fp8_compute = torch.mm(a_fp8, b_fp8.t())
a = a.reshape(-1, a_shape[-1])
c_ref = torch.mm(a, b.t())
sqnr = compute_error(c_ref, c_fp8_compute)
assert sqnr >= 25.0


class TestFloat8Linear:
def _test_linear_impl(
@@ -417,7 +476,9 @@ def test_linear_from_config_params(
"recipe_name",
[
Float8LinearRecipeName.ALL_AXISWISE,
Float8LinearRecipeName.ALL_BLOCKWISE,
Float8LinearRecipeName.LW_AXISWISE_WITH_GW_HP,
Float8LinearRecipeName.LW_BLOCKWISE_WITH_GW_HP,
],
)
@pytest.mark.parametrize("x_shape", [(16, 16), (2, 16, 16), (3, 2, 16, 16)])
2 changes: 2 additions & 0 deletions test/float8/test_compile.py
@@ -223,7 +223,9 @@ def test_inductor_from_config_params(
"recipe_name",
[
Float8LinearRecipeName.ALL_AXISWISE,
Float8LinearRecipeName.ALL_BLOCKWISE,
Float8LinearRecipeName.LW_AXISWISE_WITH_GW_HP,
Float8LinearRecipeName.LW_BLOCKWISE_WITH_GW_HP,
],
)
@unittest.skipIf(
2 changes: 2 additions & 0 deletions test/float8/test_numerics_integration.py
@@ -199,7 +199,9 @@ def test_encoder_fw_bw_from_config_params(
"recipe_name",
[
Float8LinearRecipeName.ALL_AXISWISE,
Float8LinearRecipeName.ALL_BLOCKWISE,
Float8LinearRecipeName.LW_AXISWISE_WITH_GW_HP,
Float8LinearRecipeName.LW_BLOCKWISE_WITH_GW_HP,
],
)
@pytest.mark.skipif(
51 changes: 51 additions & 0 deletions test/prototype/test_blockwise_triton.py
@@ -0,0 +1,51 @@
import pytest
import torch

from torchao.prototype.blockwise_fp8.blockwise_fp8_gemm_triton import blockwise_fp8_gemm
from torchao.prototype.blockwise_fp8.blockwise_quantization import (
fp8_blockwise_act_quant,
fp8_blockwise_weight_dequant,
fp8_blockwise_weight_quant,
)

# Problem sizes reused from the rowwise scaled linear CUTLASS tests
BLOCKWISE_FP8_SIZE_MNK = [
(2, 512, 128),
(3, 2048, 2048),
(4, 3584, 640),
(13, 8704, 8576),
(26, 18944, 1664),
(67, 6656, 1408),
]

@pytest.mark.skipif(not torch.cuda.is_available(), reason="CUDA not available")
@pytest.mark.parametrize("_, N, K", ROWWISE_SCALED_LINEAR_CUTLASS_SIZE_MNK)
def test_quant_dequant(_, N, K):
x = torch.randn(N, K).cuda()
qx, s = fp8_blockwise_weight_quant(x, block_size=128)
x_reconstructed = fp8_blockwise_weight_dequant(qx, s, block_size=128)
error = torch.norm(x - x_reconstructed) / torch.norm(x)
print(f"Relative Error: {error.item():.6f}")

assert error < 0.05, "Quant-Dequant error is too high"

@pytest.mark.skipif(not torch.cuda.is_available(), reason="CUDA not available")
@pytest.mark.parametrize("M, N, K", ROWWISE_SCALED_LINEAR_CUTLASS_SIZE_MNK)
def test_blockwise_fp8_gemm(M, N, K):
A = torch.randn(M, K).cuda()
B = torch.randn(N, K).cuda()

C = A @ B.T

A_q, A_s = fp8_blockwise_act_quant(A, block_size=128)
B_q, B_s = fp8_blockwise_weight_quant(B, block_size=128)

C_q = blockwise_fp8_gemm(A_q, A_s, B_q, B_s)
print(C_q, C)
error = torch.norm(C - C_q) / torch.norm(C)
print(f"Relative Error: {error.item():.6f}")

assert error < 0.05, "Quantize gemm error is too high"

