diff --git a/torchao/prototype/moe_training/benchmarks/benchmark_scaled_grouped_mm.py b/torchao/prototype/moe_training/benchmarks/benchmark_scaled_grouped_mm.py
index af1a652fc0..a347763fe6 100644
--- a/torchao/prototype/moe_training/benchmarks/benchmark_scaled_grouped_mm.py
+++ b/torchao/prototype/moe_training/benchmarks/benchmark_scaled_grouped_mm.py
@@ -4,7 +4,7 @@
 # This source code is licensed under the BSD 3-Clause license found in the
 # LICENSE file in the root directory of this source tree.
 # this benchmarking script is a modified version of the original script from: https://github.com/drisspg/transformer_nuggets/blob/main/transformer_nuggets/utils/benchmark.py
-
+import argparse
 import itertools
 import time
 from dataclasses import dataclass
@@ -31,7 +31,9 @@ class ExperimentConfig:
 
 @dataclass(frozen=True)
 class ExperimentResult:
-    time_us: float
+    torch_time_us: float
+    triton_time_us: float
+    triton_speedup: float
 
 
 @dataclass(frozen=True)
@@ -41,12 +43,14 @@ class Experiment:
 
 
 def get_configs() -> List[ExperimentConfig]:
-    A_shapes = [(2**8, 4096), (2**12, 4096), (2**16, 4096)]
-    B_shapes = [(4, 4096, 4096), (8, 4096, 4096), (16, 4096, 4096)]
+    A_shapes = [(2**8, 8192), (2**12, 8192), (2**16, 8192)]
+    B_shapes = [(4, 8192, 8192), (8, 8192, 8192), (16, 8192, 8192)]
     high_precision_dtypes = [torch.bfloat16]
     configs = []
     for A_shape, B_shape, high_precision_dtype in itertools.product(
-        A_shapes, B_shapes, high_precision_dtypes
+        A_shapes,
+        B_shapes,
+        high_precision_dtypes,
     ):
         configs.append(
             ExperimentConfig(
@@ -58,7 +62,9 @@ def get_configs() -> List[ExperimentConfig]:
     return configs
 
 
-def run_experiment(config: ExperimentConfig) -> ExperimentResult:
+def run_experiment(
+    config: ExperimentConfig, args: argparse.Namespace
+) -> ExperimentResult:
     # define test inputs
     A = torch.randn(
         *config.A_shape,
@@ -92,26 +98,46 @@ def warmup(func, *args, **kwargs):
         for _ in range(10):
             func(*args, **kwargs)
 
-    def forward_backward(A, B_t, offs):
-        out = _scaled_grouped_mm(A, B_t, offs=offs, out_dtype=torch.bfloat16)
+    def forward_backward(A, B_t, offs, use_triton=True):
+        out = _scaled_grouped_mm(
+            A,
+            B_t,
+            offs=offs,
+            out_dtype=torch.bfloat16,
+            use_triton_for_per_group_scales=use_triton,
+        )
         out.sum().backward()
+        torch.cuda.synchronize()
 
-    # bench triton
-    warmup(forward_backward, A, B_t, offs)
+    # benchmark torch
+    torch_func = torch.compile(forward_backward) if args.compile else forward_backward
+    warmup(torch_func, A, B_t, offs, use_triton=False)
     start_time_ns = time.perf_counter_ns()
-    forward_backward(A, B_t, offs)
-    time_ns = time.perf_counter_ns() - start_time_ns
-    time_us = time_ns / 1e3
+    torch_func(A, B_t, offs, use_triton=False)
+    torch_time_ns = time.perf_counter_ns() - start_time_ns
+    torch_time_us = torch_time_ns / 1e3
 
-    return ExperimentResult(time_us=time_us)
+    # benchmark triton
+    warmup(forward_backward, A, B_t, offs, use_triton=True)
+    start_time_ns = time.perf_counter_ns()
+    forward_backward(A, B_t, offs, use_triton=True)
+    triton_time_ns = time.perf_counter_ns() - start_time_ns
+    triton_time_us = triton_time_ns / 1e3
+
+    return ExperimentResult(
+        torch_time_us=round(torch_time_us, 3),
+        triton_time_us=round(triton_time_us, 3),
+        triton_speedup=round(torch_time_us / triton_time_us, 3),
+    )
 
 
 def print_results(experiments: List[Experiment]):
     headers = [
         "A_shape",
         "B_shape",
-        "high_precision_dtype",
-        "time_us",
+        "torch_time_us",
+        "triton_time_us",
+        "triton_speedup",
     ]
     rows = []
     for experiment in experiments:
@@ -121,19 +147,20 @@ def print_results(experiments: List[Experiment]):
             [
                 A_shape,
                 B_shape,
-                experiment.config.high_precision_dtype,
-                experiment.result.time_us,
+                experiment.result.torch_time_us,
+                experiment.result.triton_time_us,
+                experiment.result.triton_speedup,
             ]
         )
     print(tabulate(rows, headers=headers))
 
 
-def main():
+def main(args: argparse.Namespace):
     torch.random.manual_seed(123)
     configs = get_configs()
     results = []
     for config in tqdm(configs):
-        result = run_experiment(config)
+        result = run_experiment(config, args)
         results.append(Experiment(config=config, result=result))
 
     # Use Tabulate to print results
@@ -141,4 +168,7 @@ def main():
 
 
 if __name__ == "__main__":
-    main()
+    arg_parser = argparse.ArgumentParser()
+    arg_parser.add_argument("--compile", action="store_true")
+    args = arg_parser.parse_args()
+    main(args)
diff --git a/torchao/prototype/moe_training/conversion_utils.py b/torchao/prototype/moe_training/conversion_utils.py
index 928af1cf2e..4d65303b89 100644
--- a/torchao/prototype/moe_training/conversion_utils.py
+++ b/torchao/prototype/moe_training/conversion_utils.py
@@ -28,7 +28,8 @@ class MoETrainingConfig(AOBaseConfig):
 
     For all other ops, ScaledGroupedMMTensor behaves like a regular torch.Tensor.
     """
-    pass
+    # temporary config flag for testing/benchmarking, will remove before graduating out of prototype
+    use_triton_for_per_group_scales: bool = True
 
 
 @register_quantize_module_handler(MoETrainingConfig)
@@ -46,7 +47,7 @@ def _moe_training_transform(
     Returns:
         nn.Module: The modified module with swapped parameters.
     """
-    out = _swap_params(module)
+    out = _swap_params(module, config=config)
     return out
 
 
@@ -54,6 +55,7 @@ def _swap_params(
     module: nn.Module,
     *,
     module_filter_fn: Optional[Callable[[nn.Module, str], bool]] = None,
+    config: Optional[MoETrainingConfig] = None,
 ) -> nn.Module:
     """
     Recurses through the nn.Module, recursively swapping the data tensor of
@@ -69,6 +71,7 @@ def _swap_params(
     Returns:
         nn.Module: The modified module with swapped linear layers.
     """
+    use_triton = config.use_triton_for_per_group_scales if config is not None else False
     if isinstance(module, nn.Parameter) and (
         module_filter_fn is None or module_filter_fn(module, "")
     ):
@@ -77,7 +80,9 @@ def _swap_params(
             raise AssertionError(
                 f"Does not support a root nn.Parameter with children: {module}"
             )
         if not isinstance(module.data, ScaledGroupedMMTensor):
-            new_data = ScaledGroupedMMTensor(module.data)
+            new_data = ScaledGroupedMMTensor(
+                module.data, use_triton_for_per_group_scales=use_triton
+            )
             return nn.Parameter(new_data, requires_grad=module.requires_grad)
         return module
diff --git a/torchao/prototype/moe_training/scaled_grouped_mm.py b/torchao/prototype/moe_training/scaled_grouped_mm.py
index d3aaf615db..f7d470e556 100644
--- a/torchao/prototype/moe_training/scaled_grouped_mm.py
+++ b/torchao/prototype/moe_training/scaled_grouped_mm.py
@@ -14,7 +14,11 @@
     triton_fp8_col_major_jagged_colwise_scales,
     triton_fp8_row_major_jagged_rowwise_scales,
 )
-from torchao.prototype.moe_training.utils import _is_column_major
+from torchao.prototype.moe_training.utils import (
+    _is_column_major,
+    _to_2d_jagged_float8_tensor_colwise,
+    _to_2d_jagged_float8_tensor_rowwise,
+)
 
 
 def _scaled_grouped_mm(
@@ -22,6 +26,7 @@ def _scaled_grouped_mm(
     B_t: torch.Tensor,
     offs: torch.Tensor,
     out_dtype: Optional[torch.dtype] = torch.bfloat16,
+    use_triton_for_per_group_scales: bool = True,
 ) -> torch.Tensor:
     """
     This function performs dynamic float8 quantization with row-wise scaling
@@ -34,6 +39,7 @@ def _scaled_grouped_mm(
             and in column-major memory layout.
         offs (int32 torch.Tensor): The offsets to use to mark the starting index of each group along dim0 of the A tensor.
         out_dtype (Optional[torch.dtype]): The dtype of the output tensor. Currently only torch.bfloat16 is supported.
+        use_triton_for_per_group_scales (bool): Whether to use custom triton kernels to compute per-group scales. Default is True.
     """
     return _Float8GroupedMM.apply(
         A,
@@ -53,6 +59,7 @@ def forward(
         B_t: torch.Tensor,
         offs: torch.Tensor,
         out_dtype: Optional[torch.dtype] = torch.bfloat16,
+        use_triton_for_per_group_scales: bool = True,
     ) -> torch.Tensor:
         # torchao _scaled_grouped_mm only supports A=2D, B=3D.
         assert A.ndim == 2, "A must be 2D"
@@ -136,9 +143,16 @@ def forward(
         # Store what we need for backward.
         ctx.save_for_backward(A, B_fp8_col_major, B_scales, offs)
         ctx.out_dtype = out_dtype
+        ctx.use_triton_for_per_group_scales = use_triton_for_per_group_scales
 
         # Perform scaled grouped GEMM and return result.
         # output shape: scaled grouped mm of (M,K) @ (B,K,N) = (M,N)
+        assert not _is_column_major(A_fp8_row_major), (
+            "A must be row-major for output = A @ B"
+        )
+        assert _is_column_major(B_t_fp8_col_major), (
+            "B must be column-major for output = A @ B"
+        )
         return torch._scaled_grouped_mm(
             A_fp8_row_major,
             B_t_fp8_col_major,
@@ -153,6 +167,7 @@ def forward(
     def backward(ctx, grad_output: torch.Tensor):
         A, B_fp8_col_major, B_scales, offs = ctx.saved_tensors
         out_dtype = ctx.out_dtype
+        use_triton_for_per_group_scales = ctx.use_triton_for_per_group_scales
 
         # Convert grad_output to float8, row-major for left operand of grouped GEMM
         # needed for grad_A: grad_output @ B
@@ -175,6 +190,12 @@ def backward(ctx, grad_output: torch.Tensor):
         #
         # grad_A = grad_output @ B
         # grad_A = scaled grouped mm of (M,N) @ (B,N,K) = (M,K)
+        assert not _is_column_major(grad_output_fp8_row_major), (
+            "grad_output must be row-major for grad_A = grad_output @ B"
+        )
+        assert _is_column_major(B_fp8_col_major), (
+            "B must be column-major for grad_A = grad_output @ B"
+        )
         grad_A = torch._scaled_grouped_mm(
             grad_output_fp8_row_major,
             B_fp8_col_major,
@@ -195,25 +216,42 @@ def backward(ctx, grad_output: torch.Tensor):
         # grad_B is a special case. both operands of the grouped gemm will be 2D with offsets determing the "groups."
 
         # Compute scales for grad_output_t and A, which are both 2D tensors with offsets which define the "jagged" groups.
+        per_group_rowwise_scale_func = (
+            triton_fp8_row_major_jagged_rowwise_scales
+            if use_triton_for_per_group_scales
+            else _to_2d_jagged_float8_tensor_rowwise
+        )
+        per_group_colwise_scale_func = (
+            triton_fp8_col_major_jagged_colwise_scales
+            if use_triton_for_per_group_scales
+            else _to_2d_jagged_float8_tensor_colwise
+        )
+
         grad_output_t_fp8_row_major, grad_output_t_scales = (
-            triton_fp8_row_major_jagged_rowwise_scales(
+            per_group_rowwise_scale_func(
                 grad_output_t_row_major,
                 offs,
-                output_dtype=torch.float8_e4m3fn,
+                torch.float8_e4m3fn,
                 round_scales_to_power_of_2=True,
             )
         )
 
-        A_fp8_col_major, A_scales = triton_fp8_col_major_jagged_colwise_scales(
+        A_fp8_col_major, A_scales = per_group_colwise_scale_func(
             A_col_major,
             offs,
-            output_dtype=torch.float8_e4m3fn,
+            torch.float8_e4m3fn,
             round_scales_to_power_of_2=True,
         )
 
         # Compute grad_B = grad_output_t @ A.
         # grad_B = grad_output_t @ A
         # grad_B = (N,M) @ (M,K) = (N,K)
+        assert not _is_column_major(grad_output_t_fp8_row_major), (
+            "grad_output_t must be row-major for grad_B = grad_output_t @ A"
+        )
+        assert _is_column_major(A_fp8_col_major), (
+            "A must be column-major for grad_B = grad_output_t @ A"
+        )
         grad_B = torch._scaled_grouped_mm(
             grad_output_t_fp8_row_major,
             A_fp8_col_major,
diff --git a/torchao/prototype/moe_training/tensor.py b/torchao/prototype/moe_training/tensor.py
index 2a929d3b76..8d7a8f815b 100644
--- a/torchao/prototype/moe_training/tensor.py
+++ b/torchao/prototype/moe_training/tensor.py
@@ -12,9 +12,16 @@ class ScaledGroupedMMTensor(torch.Tensor):
 
     grouped_mm_func_name = "_grouped_mm"
     offs_arg_name = "offs"
+    use_triton_for_per_group_scales = True
 
-    def __init__(self, data: torch.Tensor):
+    def __init__(
+        self, data: torch.Tensor, use_triton_for_per_group_scales: bool = True
+    ):
         self._data = data
+        self._use_triton_for_per_group_scales = use_triton_for_per_group_scales
+
+    def __repr__(self):
+        return f"ScaledGroupedMMTensor(use_triton_for_per_group_scales={self._use_triton_for_per_group_scales}, {self._data})"
 
     @classmethod
     def __torch_function__(cls, func, types, args, kwargs={}):
@@ -31,5 +38,16 @@ def __torch_function__(cls, func, types, args, kwargs={}):
             B_is_3d = B.dim() == 3
             has_offs = kwargs.get(cls.offs_arg_name) is not None
             if A_is_2d and B_is_3d and has_offs:
-                return _scaled_grouped_mm(*args, **kwargs)
+                # prefer to use B to check use_triton, as that will be the weight/nn.Parameter
+                # that is converted to ScaledGroupedMMTensor
+                use_triton = (
+                    B._use_triton_for_per_group_scales
+                    if isinstance(B, cls)
+                    else A._use_triton_for_per_group_scales
+                )
+                return _scaled_grouped_mm(
+                    *args,
+                    use_triton_for_per_group_scales=use_triton,
+                    **kwargs,
+                )
         return super().__torch_function__(func, types, args, kwargs)
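Usage note (not part of the patch): the new use_triton_for_per_group_scales flag flows from MoETrainingConfig through _swap_params into ScaledGroupedMMTensor, whose __torch_function__ override forwards it to _scaled_grouped_mm; there it selects either the triton jagged-scale kernels or the pure-torch _to_2d_jagged_float8_tensor_* fallbacks in the backward pass. Below is a minimal sketch of toggling the flag at conversion time, assuming the standard torchao quantize_ entry point that register_quantize_module_handler hooks into; the ToyExperts module and the filter_fn are hypothetical and exist only for illustration. The two code paths can also be compared directly with the benchmark script above, e.g. python benchmark_scaled_grouped_mm.py --compile.

    import torch
    from torch import nn

    from torchao.prototype.moe_training.conversion_utils import MoETrainingConfig
    from torchao.quantization import quantize_


    class ToyExperts(nn.Module):
        # hypothetical stand-in for an MoE experts module: a single 3D expert
        # weight, matching the layout described in the MoETrainingConfig docstring
        def __init__(self, num_experts: int = 4, dim: int = 64):
            super().__init__()
            self.w = nn.Parameter(
                torch.randn(num_experts, dim, dim, dtype=torch.bfloat16)
            )


    model = ToyExperts()

    # False would fall back to the pure-torch _to_2d_jagged_float8_tensor_rowwise/
    # _colwise scaling path in the backward pass instead of the triton kernels
    config = MoETrainingConfig(use_triton_for_per_group_scales=True)

    # swap the expert weight's data tensor for a ScaledGroupedMMTensor carrying the flag
    quantize_(model, config, filter_fn=lambda mod, fqn: isinstance(mod, ToyExperts))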