
[float8 moe training] make using triton kernels for per-group scaling configurable #2405


Merged (5 commits, Jun 18, 2025)
@@ -4,7 +4,7 @@
# This source code is licensed under the BSD 3-Clause license found in the
# LICENSE file in the root directory of this source tree.
# this benchmarking script is a modified version of the original script from: https://github.com/drisspg/transformer_nuggets/blob/main/transformer_nuggets/utils/benchmark.py

import argparse
import itertools
import time
from dataclasses import dataclass
@@ -31,7 +31,9 @@ class ExperimentConfig:

@dataclass(frozen=True)
class ExperimentResult:
time_us: float
torch_time_us: float
triton_time_us: float
triton_speedup: float


@dataclass(frozen=True)
@@ -41,12 +43,14 @@ class Experiment:


def get_configs() -> List[ExperimentConfig]:
A_shapes = [(2**8, 4096), (2**12, 4096), (2**16, 4096)]
B_shapes = [(4, 4096, 4096), (8, 4096, 4096), (16, 4096, 4096)]
A_shapes = [(2**8, 8192), (2**12, 8192), (2**16, 8192)]
B_shapes = [(4, 8192, 8192), (8, 8192, 8192), (16, 8192, 8192)]
high_precision_dtypes = [torch.bfloat16]
configs = []
for A_shape, B_shape, high_precision_dtype in itertools.product(
A_shapes, B_shapes, high_precision_dtypes
A_shapes,
B_shapes,
high_precision_dtypes,
):
configs.append(
ExperimentConfig(
@@ -58,7 +62,9 @@ def get_configs() -> List[ExperimentConfig]:
return configs


def run_experiment(config: ExperimentConfig) -> ExperimentResult:
def run_experiment(
config: ExperimentConfig, args: argparse.Namespace
) -> ExperimentResult:
# define test inputs
A = torch.randn(
*config.A_shape,
@@ -92,26 +98,46 @@ def warmup(func, *args, **kwargs):
for _ in range(10):
func(*args, **kwargs)

def forward_backward(A, B_t, offs):
out = _scaled_grouped_mm(A, B_t, offs=offs, out_dtype=torch.bfloat16)
def forward_backward(A, B_t, offs, use_triton=True):
out = _scaled_grouped_mm(
A,
B_t,
offs=offs,
out_dtype=torch.bfloat16,
use_triton_for_per_group_scales=use_triton,
)
out.sum().backward()
torch.cuda.synchronize()

# bench triton
warmup(forward_backward, A, B_t, offs)
# benchmark torch
torch_func = torch.compile(forward_backward) if args.compile else forward_backward
warmup(torch_func, A, B_t, offs, use_triton=False)
start_time_ns = time.perf_counter_ns()
forward_backward(A, B_t, offs)
time_ns = time.perf_counter_ns() - start_time_ns
time_us = time_ns / 1e3
torch_func(A, B_t, offs, use_triton=False)
torch_time_ns = time.perf_counter_ns() - start_time_ns
torch_time_us = torch_time_ns / 1e3

return ExperimentResult(time_us=time_us)
# benchmark triton
warmup(forward_backward, A, B_t, offs, use_triton=True)
start_time_ns = time.perf_counter_ns()
forward_backward(A, B_t, offs, use_triton=True)
triton_time_ns = time.perf_counter_ns() - start_time_ns
triton_time_us = triton_time_ns / 1e3

return ExperimentResult(
torch_time_us=round(torch_time_us, 3),
triton_time_us=round(triton_time_us, 3),
triton_speedup=round(torch_time_us / triton_time_us, 3),
)


def print_results(experiments: List[Experiment]):
headers = [
"A_shape",
"B_shape",
"high_precision_dtype",
"time_us",
"torch_time_us",
"triton_time_us",
"triton_speedup",
]
rows = []
for experiment in experiments:
@@ -121,24 +147,28 @@ def print_results(experiments: List[Experiment]):
[
A_shape,
B_shape,
experiment.config.high_precision_dtype,
experiment.result.time_us,
experiment.result.torch_time_us,
experiment.result.triton_time_us,
experiment.result.triton_speedup,
]
)
print(tabulate(rows, headers=headers))


def main():
def main(args: argparse.Namespace):
torch.random.manual_seed(123)
configs = get_configs()
results = []
for config in tqdm(configs):
result = run_experiment(config)
result = run_experiment(config, args)
results.append(Experiment(config=config, result=result))

# Use Tabulate to print results
print_results(results)


if __name__ == "__main__":
main()
arg_parser = argparse.ArgumentParser()
arg_parser.add_argument("--compile", action="store_true")
args = arg_parser.parse_args()
main(args)
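
For reference, a minimal standalone sketch of the new CLI plumbing: the parsed --compile flag is what main() forwards into run_experiment(), where it decides whether forward_backward is wrapped in torch.compile before timing. The simulated argv below is illustrative.

import argparse

# Mirror of the benchmark's __main__ block: --compile becomes a boolean on the
# Namespace that run_experiment() uses to torch.compile the torch-scaling path.
parser = argparse.ArgumentParser()
parser.add_argument("--compile", action="store_true")
args = parser.parse_args(["--compile"])  # simulates `python <script>.py --compile`
print(args.compile)  # True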
11 changes: 8 additions & 3 deletions torchao/prototype/moe_training/conversion_utils.py
@@ -28,7 +28,8 @@ class MoETrainingConfig(AOBaseConfig):
For all other ops, ScaledGroupedMMTensor behaves like a regular torch.Tensor.
"""

pass
# temporary config flag for testing/benchmarking, will remove before graduating out of prototype
use_triton_for_per_group_scales: bool = True


@register_quantize_module_handler(MoETrainingConfig)
@@ -46,14 +47,15 @@ def _moe_training_transform(
Returns:
nn.Module: The modified module with swapped parameters.
"""
out = _swap_params(module)
out = _swap_params(module, config=config)
return out


def _swap_params(
module: nn.Module,
*,
module_filter_fn: Optional[Callable[[nn.Module, str], bool]] = None,
config: Optional[MoETrainingConfig] = None,
) -> nn.Module:
"""
Recurses through the nn.Module, recursively swapping the data tensor of
@@ -69,6 +71,7 @@
Returns:
nn.Module: The modified module with swapped linear layers.
"""
use_triton = config.use_triton_for_per_group_scales if config is not None else False
if isinstance(module, nn.Parameter) and (
module_filter_fn is None or module_filter_fn(module, "")
):
@@ -77,7 +80,9 @@
f"Does not support a root nn.Parameter with children: {module}"
)
if not isinstance(module.data, ScaledGroupedMMTensor):
new_data = ScaledGroupedMMTensor(module.data)
new_data = ScaledGroupedMMTensor(
module.data, use_triton_for_per_group_scales=use_triton
)
return nn.Parameter(new_data, requires_grad=module.requires_grad)
return module

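
A rough usage sketch for the new config field, assuming conversion goes through torchao's quantize_ entry point; the toy module, its 3D expert weight shape, and the "experts" filter are illustrative assumptions, not part of this PR.

import torch
import torch.nn as nn
from torchao.quantization import quantize_
from torchao.prototype.moe_training.conversion_utils import MoETrainingConfig

class ToyExperts(nn.Module):
    def __init__(self):
        super().__init__()
        # 3D grouped expert weight (num_experts, dim_in, dim_out); sizes are illustrative
        self.weight = nn.Parameter(torch.randn(4, 128, 128, dtype=torch.bfloat16))

model = nn.ModuleDict({"experts": ToyExperts()})

# Swap expert parameters for ScaledGroupedMMTensor, with the triton per-group
# scaling kernels disabled so the pure-torch scaling path is exercised.
config = MoETrainingConfig(use_triton_for_per_group_scales=False)
quantize_(model, config, filter_fn=lambda mod, fqn: "experts" in fqn)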
48 changes: 43 additions & 5 deletions torchao/prototype/moe_training/scaled_grouped_mm.py
@@ -14,14 +14,19 @@
triton_fp8_col_major_jagged_colwise_scales,
triton_fp8_row_major_jagged_rowwise_scales,
)
from torchao.prototype.moe_training.utils import _is_column_major
from torchao.prototype.moe_training.utils import (
_is_column_major,
_to_2d_jagged_float8_tensor_colwise,
_to_2d_jagged_float8_tensor_rowwise,
)


def _scaled_grouped_mm(
A: torch.Tensor,
B_t: torch.Tensor,
offs: torch.Tensor,
out_dtype: Optional[torch.dtype] = torch.bfloat16,
use_triton_for_per_group_scales: bool = True,
) -> torch.Tensor:
"""
This function performs dynamic float8 quantization with row-wise scaling
@@ -34,6 +39,7 @@ def _scaled_grouped_mm(
and in column-major memory layout.
offs (int32 torch.Tensor): The offsets to use to mark the starting index of each group along dim0 of the A tensor.
out_dtype (Optional[torch.dtype]): The dtype of the output tensor. Currently only torch.bfloat16 is supported.
use_triton_for_per_group_scales (bool): Whether to use custom triton kernels to compute per-group scales. Default is True.
"""
return _Float8GroupedMM.apply(
A,
@@ -53,6 +59,7 @@ def forward(
B_t: torch.Tensor,
offs: torch.Tensor,
out_dtype: Optional[torch.dtype] = torch.bfloat16,
use_triton_for_per_group_scales: bool = True,
) -> torch.Tensor:
# torchao _scaled_grouped_mm only supports A=2D, B=3D.
assert A.ndim == 2, "A must be 2D"
@@ -136,9 +143,16 @@ def forward(
# Store what we need for backward.
ctx.save_for_backward(A, B_fp8_col_major, B_scales, offs)
ctx.out_dtype = out_dtype
ctx.use_triton_for_per_group_scales = use_triton_for_per_group_scales

# Perform scaled grouped GEMM and return result.
# output shape: scaled grouped mm of (M,K) @ (B,K,N) = (M,N)
assert not _is_column_major(A_fp8_row_major), (
"A must be row-major for output = A @ B"
)
assert _is_column_major(B_t_fp8_col_major), (
"B must be column-major for output = A @ B"
)
return torch._scaled_grouped_mm(
A_fp8_row_major,
B_t_fp8_col_major,
@@ -153,6 +167,7 @@ def forward(
def backward(ctx, grad_output: torch.Tensor):
A, B_fp8_col_major, B_scales, offs = ctx.saved_tensors
out_dtype = ctx.out_dtype
use_triton_for_per_group_scales = ctx.use_triton_for_per_group_scales

# Convert grad_output to float8, row-major for left operand of grouped GEMM
# needed for grad_A: grad_output @ B
@@ -175,6 +190,12 @@ def backward(ctx, grad_output: torch.Tensor):
#
# grad_A = grad_output @ B
# grad_A = scaled grouped mm of (M,N) @ (B,N,K) = (M,K)
assert not _is_column_major(grad_output_fp8_row_major), (
"grad_output must be row-major for grad_A = grad_output @ B"
)
assert _is_column_major(B_fp8_col_major), (
"B must be column-major for grad_A = grad_output @ B"
)
grad_A = torch._scaled_grouped_mm(
grad_output_fp8_row_major,
B_fp8_col_major,
@@ -195,25 +216,42 @@ def backward(ctx, grad_output: torch.Tensor):

# grad_B is a special case. Both operands of the grouped gemm will be 2D, with offsets determining the "groups."
# Compute scales for grad_output_t and A, which are both 2D tensors with offsets that define the "jagged" groups.
per_group_rowwise_scale_func = (
triton_fp8_row_major_jagged_rowwise_scales
if use_triton_for_per_group_scales
else _to_2d_jagged_float8_tensor_rowwise
)
per_group_colwise_scale_func = (
triton_fp8_col_major_jagged_colwise_scales
if use_triton_for_per_group_scales
else _to_2d_jagged_float8_tensor_colwise
)

grad_output_t_fp8_row_major, grad_output_t_scales = (
triton_fp8_row_major_jagged_rowwise_scales(
per_group_rowwise_scale_func(
grad_output_t_row_major,
offs,
output_dtype=torch.float8_e4m3fn,
torch.float8_e4m3fn,
round_scales_to_power_of_2=True,
)
)

A_fp8_col_major, A_scales = triton_fp8_col_major_jagged_colwise_scales(
A_fp8_col_major, A_scales = per_group_colwise_scale_func(
A_col_major,
offs,
output_dtype=torch.float8_e4m3fn,
torch.float8_e4m3fn,
round_scales_to_power_of_2=True,
)

# Compute grad_B = grad_output_t @ A.
# grad_B = grad_output_t @ A
# grad_B = (N,M) @ (M,K) = (N,K)
assert not _is_column_major(grad_output_t_fp8_row_major), (
"grad_output_t must be row-major for grad_B = grad_output_t @ A"
)
assert _is_column_major(A_fp8_col_major), (
"A must be column-major for grad_B = grad_output_t @ A"
)
grad_B = torch._scaled_grouped_mm(
grad_output_t_fp8_row_major,
A_fp8_col_major,
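
As a usage note, a hedged sketch of calling the op directly with the new keyword; shapes follow the benchmark configs and the layout requirements stated in the docstring, and a CUDA device with float8 support is assumed.

import torch
from torchao.prototype.moe_training.scaled_grouped_mm import _scaled_grouped_mm

# A must be 2D row-major; B_t must be 3D column-major (built here by transposing
# a contiguous (G, N, K) tensor); offs are int32 group-end offsets along A's dim 0.
G, M, K, N = 4, 4096, 8192, 8192
A = torch.randn(M, K, dtype=torch.bfloat16, device="cuda", requires_grad=True)
B = torch.randn(G, N, K, dtype=torch.bfloat16, device="cuda", requires_grad=True)
B_t = B.transpose(-2, -1)  # shape (G, K, N), column-major memory layout
offs = torch.arange(M // G, M + 1, M // G, dtype=torch.int32, device="cuda")

out = _scaled_grouped_mm(
    A, B_t, offs=offs, out_dtype=torch.bfloat16,
    use_triton_for_per_group_scales=False,  # fall back to the torch-native scaling path
)
out.sum().backward()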
22 changes: 20 additions & 2 deletions torchao/prototype/moe_training/tensor.py
@@ -12,9 +12,16 @@ class ScaledGroupedMMTensor(torch.Tensor):

grouped_mm_func_name = "_grouped_mm"
offs_arg_name = "offs"
use_triton_for_per_group_scales = True

def __init__(self, data: torch.Tensor):
def __init__(
self, data: torch.Tensor, use_triton_for_per_group_scales: bool = True
):
self._data = data
self._use_triton_for_per_group_scales = use_triton_for_per_group_scales

def __repr__(self):
return f"ScaledGroupedMMTensor(use_triton_for_per_group_scales={self._use_triton_for_per_group_scales}, {self._data})"

@classmethod
def __torch_function__(cls, func, types, args, kwargs={}):
@@ -31,5 +38,16 @@ def __torch_function__(cls, func, types, args, kwargs={}):
B_is_3d = B.dim() == 3
has_offs = kwargs.get(cls.offs_arg_name) is not None
if A_is_2d and B_is_3d and has_offs:
return _scaled_grouped_mm(*args, **kwargs)
# prefer to use B to check use_triton, as that will be the weight/nn.Parameter
# that is converted to ScaledGroupedMMTensor
use_triton = (
B._use_triton_for_per_group_scales
if isinstance(B, cls)
else A._use_triton_for_per_group_scales
)
return _scaled_grouped_mm(
*args,
use_triton_for_per_group_scales=use_triton,
**kwargs,
)
return super().__torch_function__(func, types, args, kwargs)
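
Lastly, a small sketch of constructing the wrapper directly, mirroring what _swap_params now does; the weight shape is an illustrative assumption.

import torch
from torchao.prototype.moe_training.tensor import ScaledGroupedMMTensor

# The flag rides on the tensor subclass and is read back in __torch_function__
# when a 2D x 3D _grouped_mm call with offs is intercepted.
expert_weight = torch.randn(4, 1024, 1024, dtype=torch.bfloat16)
wrapped = ScaledGroupedMMTensor(expert_weight, use_triton_for_per_group_scales=False)
param = torch.nn.Parameter(wrapped, requires_grad=True)
print(wrapped)  # repr includes use_triton_for_per_group_scales=False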