csrc/quantization/w8a8/int8/scaled_quant.cu: 3 changes (3 additions, 0 deletions)
@@ -1,5 +1,6 @@
 #include <ATen/cuda/CUDAContext.h>
 #include <torch/all.h>
+#include <c10/cuda/CUDAGuard.h>

 #include <cmath>

@@ -275,6 +276,7 @@ void static_scaled_int8_quant(torch::Tensor& out, // [..., hidden_size]
   int const num_tokens = input.numel() / hidden_size;
   dim3 const grid(num_tokens);
   dim3 const block(std::min(hidden_size, 256));
+  const at::cuda::OptionalCUDAGuard device_guard(device_of(input));
   const cudaStream_t stream = at::cuda::getCurrentCUDAStream();
   VLLM_DISPATCH_FLOATING_TYPES(
       input.scalar_type(), "static_scaled_int8_quant_kernel", [&] {
@@ -306,6 +308,7 @@ void dynamic_scaled_int8_quant(
   int const num_tokens = input.numel() / hidden_size;
   dim3 const grid(num_tokens);
   dim3 const block(std::min(hidden_size, 256));
+  const at::cuda::OptionalCUDAGuard device_guard(device_of(input));
   const cudaStream_t stream = at::cuda::getCurrentCUDAStream();
   VLLM_DISPATCH_FLOATING_TYPES(
       input.scalar_type(), "dynamic_scaled_int8_quant_kernel", [&] {
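The added OptionalCUDAGuard matters when the input tensor lives on a CUDA device other than the current one: getCurrentCUDAStream() returns the current device's stream and the kernel launch also targets the current device, so without the guard the quant kernel could be issued against the wrong GPU. A minimal repro sketch, assuming a machine with at least two GPUs and vLLM's _custom_ops wrapper (the import path and the two-GPU setup are assumptions, not part of this PR):

```python
import torch
from vllm import _custom_ops as ops  # assumed import path for the quant wrappers

torch.cuda.set_device(0)  # make cuda:0 the current device
x = torch.randn(16, 1024, dtype=torch.bfloat16, device="cuda:1")

# Dynamic per-token int8 quantization of a tensor that is NOT on the current
# device; the guard added in this diff scopes the launch to device_of(input),
# so the work is enqueued on cuda:1's stream instead of cuda:0's.
out, scales, _ = ops.scaled_int8_quant(x)
assert out.device == x.device and scales.device == x.device
```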
tests/kernels/core/test_fused_quant_layernorm.py: 35 changes (25 additions, 10 deletions)
@@ -11,7 +11,7 @@

 DTYPES = [torch.bfloat16, torch.float]
 QUANT_DTYPES = [torch.int8, torch.float8_e4m3fn]
-VEC_HIDDEN_SIZES = range(1024, 1030)
+VEC_HIDDEN_SIZES = [1024, 1025, 1027, 1029]
 # Avoid combinatorial explosion with full Cartesian product
 NUM_TOKENS_HIDDEN_SIZES = [
     *[(1, i) for i in [1, 64, *VEC_HIDDEN_SIZES, 5120, 5137]],
@@ -65,7 +65,7 @@ def ref_dynamic_per_token_quant(
         )
     else:
         assert quant_dtype == torch.int8
-        torch_out, scales = ops.scaled_int8_quant(torch_out)
+        torch_out, scales, _ = ops.scaled_int8_quant(torch_out)

     return torch_out, scales, residual

@@ -109,7 +109,7 @@ def ops_impl(

 @pytest.mark.parametrize("num_tokens, hidden_size", NUM_TOKENS_HIDDEN_SIZES)
 @pytest.mark.parametrize("add_residual", ADD_RESIDUAL)
-@pytest.mark.parametrize("scale_ub", SCALE_UBS)
+@pytest.mark.parametrize("has_scale_ub", SCALE_UBS)
 @pytest.mark.parametrize("dtype", DTYPES)
 @pytest.mark.parametrize("quant_dtype", QUANT_DTYPES)
 @pytest.mark.parametrize("seed", SEEDS)
@@ -119,7 +119,7 @@ def test_rms_norm(
     num_tokens: int,
     hidden_size: int,
     add_residual: bool,
-    scale_ub: bool,
+    has_scale_ub: bool,
     dtype: torch.dtype,
     quant_dtype: torch.dtype,
     seed: int,
@@ -130,7 +130,7 @@
     torch.cuda.manual_seed(seed)
     torch.set_default_device(device)

-    if scale_ub is not None and quant_dtype != torch.float8_e4m3fn:
+    if has_scale_ub and quant_dtype != torch.float8_e4m3fn:
         # skip
         return

@@ -143,9 +143,11 @@
     scale = 1 / (hidden_size)
     x = torch.randn(num_tokens, hidden_size, dtype=dtype) * scale
     residual = torch.randn_like(x) * scale if add_residual else None
-    if scale_ub is not None:
+    if has_scale_ub:
         rms_x, _ = ref_rms_norm(layer, x, residual)
         scale_ub = torch.mean(rms_x).to(dtype=torch.float32, device="cuda")
+    else:
+        scale_ub = None

     ref_out, ref_scales, ref_residual = ref_impl(
         layer, x, quant_dtype, residual, scale_ub
@@ -156,14 +158,27 @@

     assert ref_out.dtype == quant_dtype
     assert ops_out.dtype == quant_dtype
-    assert torch.allclose(ref_scales, ops_scales)
     if quant_dtype == torch.int8:
+        assert torch.allclose(ref_scales, ops_scales, atol=1e-6)
         # big atol to account for round-off errors.
         assert torch.allclose(ref_out, ops_out, atol=1)
     else:
-        assert torch.allclose(
-            ref_out.to(dtype=torch.float32), ops_out.to(dtype=torch.float32)
-        )
+        assert torch.allclose(ref_scales, ops_scales)
+        a = ref_out.to(dtype=torch.float32)
+        b = ops_out.to(dtype=torch.float32)
+        ok = torch.allclose(a, b)
+        if not ok:
+            # fallback: compare dequantized values with relaxed tolerance
+            a_deq = a * ref_scales.view(-1, 1)
+            b_deq = b * ops_scales.view(-1, 1)
+            # NOTE: It is possible that some future test cases trigger this
+            # max diff due to precision issues. If such an error is
+            # encountered, it's recommended to inspect the differences between
+            # all corresponding elements from each tensor (e.g. by looping over
+            # them) and checking how many elements the max diff error shows up
+            # on (just a few bad elements should still be considered acceptable).
+            ok = torch.allclose(a_deq, b_deq, rtol=5e-2, atol=5e-2)
+        assert ok
     if add_residual:
         assert torch.allclose(ref_residual, ops_residual)

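The NOTE in the fallback branch recommends checking how many elements actually exceed the relaxed tolerance before treating a max-diff failure as a real regression. A hypothetical standalone helper along those lines (count_mismatches is not part of the PR; the 5e-2 tolerances mirror the test, and the check uses the same |a - b| <= atol + rtol * |b| criterion as torch.allclose):

```python
import torch

def count_mismatches(a_deq: torch.Tensor, b_deq: torch.Tensor,
                     rtol: float = 5e-2, atol: float = 5e-2) -> int:
    # Count elements violating the criterion torch.allclose uses:
    # |a - b| <= atol + rtol * |b|
    bad = (a_deq - b_deq).abs() > atol + rtol * b_deq.abs()
    return int(bad.sum().item())

# Usage sketch for debugging a failure of the relaxed comparison: per the NOTE,
# only a handful of offending elements out of num_tokens * hidden_size would
# still be considered acceptable.
```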