diff --git a/test/float8/test_dtensor.py b/test/float8/test_dtensor.py
index 472d254b47..5985a3f5b5 100644
--- a/test/float8/test_dtensor.py
+++ b/test/float8/test_dtensor.py
@@ -15,8 +15,6 @@
 import pytest
 
 import torch
-import torch.nn as nn
-import torch.nn.functional as F
 
 from torchao.utils import TORCH_VERSION_AT_LEAST_2_5
 
@@ -49,6 +47,7 @@
 )
 from torchao.float8.float8_utils import e4m3_dtype, tensor_to_scale
 from torchao.float8.fsdp_utils import WeightWithDynamicFloat8CastTensor
+from torchao.testing.float8.dtensor_utils import ToyModel
 
 
 def setup_distributed():
@@ -59,28 +58,6 @@ def setup_distributed():
     return device_mesh
 
 
-class FeedForward(nn.Module):
-    """MLP based model"""
-
-    def __init__(self):
-        super(FeedForward, self).__init__()
-        self.w1 = nn.Linear(16, 32, bias=False)
-        self.w2 = nn.Linear(16, 32, bias=False)
-        self.out_proj = nn.Linear(32, 16, bias=False)
-
-    def forward(self, x):
-        return self.out_proj(F.silu(self.w1(x)) * self.w2(x))
-
-
-class ToyModel(nn.Module):
-    def __init__(self):
-        super(ToyModel, self).__init__()
-        self.ffn = FeedForward()
-
-    def forward(self, x):
-        return self.ffn(x)
-
-
 def _test_scaled_mm(mesh: DeviceMesh, size=16):
     device = mesh.device_type
     fp8_dtype = e4m3_dtype
diff --git a/test/float8/test_dtensor.sh b/test/float8/test_dtensor.sh
index 2e38feffec..585a9014b1 100755
--- a/test/float8/test_dtensor.sh
+++ b/test/float8/test_dtensor.sh
@@ -8,4 +8,8 @@ if python -c 'import torch;print(torch.cuda.is_available())' | grep -q "False";
     exit
 fi
 
+# integration tests for TP/SP
 NCCL_DEBUG=WARN torchrun --nproc_per_node 2 test/float8/test_dtensor.py
+
+# integration smoke tests for FSDP2 + TP
+NCCL_DEBUG=WARN torchrun --nproc_per_node 4 test/float8/test_fsdp2_tp.py
diff --git a/test/float8/test_fsdp2_tp.py b/test/float8/test_fsdp2_tp.py
new file mode 100644
index 0000000000..fa3d30410b
--- /dev/null
+++ b/test/float8/test_fsdp2_tp.py
@@ -0,0 +1,121 @@
+# Copyright (c) Meta Platforms, Inc. and affiliates.
+# All rights reserved.
+#
+# This source code is licensed under the BSD 3-Clause license found in the
+# LICENSE file in the root directory of this source tree.
+"""
+Test numerics of manually defined float16 TP vs float8 TP of toy models
+
+Note: for now, this does not run in CI.
+TODO(future): make this run in CI
+"""
+
+import copy
+import os
+
+import pytest
+import torch
+
+from torchao.utils import TORCH_VERSION_AT_LEAST_2_5
+
+if not TORCH_VERSION_AT_LEAST_2_5:
+    pytest.skip("Unsupported PyTorch version", allow_module_level=True)
+
+from torch.distributed._composable.fsdp import fully_shard
+from torch.distributed.device_mesh import DeviceMesh, init_device_mesh
+from torch.distributed.tensor.parallel import parallelize_module
+from tqdm import tqdm
+
+from torchao.float8 import Float8LinearConfig
+from torchao.float8.float8_linear_utils import convert_to_float8_training
+from torchao.float8.float8_tensor_parallel import (
+    Float8ColwiseParallel,
+    Float8RowwiseParallel,
+)
+from torchao.testing.float8.dtensor_utils import ToyModel
+
+
+def setup_distributed():
+    world_size = int(os.environ.get("WORLD_SIZE", -1))
+
+    # https://pytorch.org/tutorials/recipes/distributed_device_mesh.html
+    device_mesh = init_device_mesh(
+        "cuda",
+        (world_size // 2, 2),
+        mesh_dim_names=("dp", "tp"),
+    )
+    # seed must be the same in all processes
+    torch.manual_seed(1)
+    return device_mesh
+
+
+def _test_fp8_mlp_tensor_parallelism_base(
+    mesh: DeviceMesh, size=16, compile: bool = False
+):
+    device = mesh.device_type
+
+    config = Float8LinearConfig(
+        emulate=True,
+        enable_fsdp_float8_all_gather=True,
+    )
+
+    toy_model = ToyModel().to(device)
+
+    tp_model = copy.deepcopy(toy_model)
+    tp_model = convert_to_float8_training(tp_model, config=config)
+
+    # apply TP
+    tp_model = parallelize_module(
+        tp_model,
+        mesh["tp"],
+        {
+            "ffn.w1": Float8ColwiseParallel(),
+            "ffn.w2": Float8ColwiseParallel(),
+            "ffn.out_proj": Float8RowwiseParallel(),
+        },
+    )
+
+    if compile:
+        tp_model = torch.compile(tp_model)
+
+    # apply FSDP
+    fsdp_config = {"mesh": mesh["dp"]}
+    tp_model = fully_shard(tp_model, **fsdp_config)
+
+    x_fp32 = torch.rand(size, size * 2, size, device=device, requires_grad=False)
+    x_fp32_tp_input = x_fp32.clone()
+
+    tp_out = tp_model(x_fp32_tp_input)
+    tp_out.sum().backward()
+    torch.cuda.synchronize()
+
+    # TODO(future PR): test numerics, and add more cases
+
+
+def _test_fp8_mlp_tensor_parallelism_eager(mesh: DeviceMesh, size=16):
+    _test_fp8_mlp_tensor_parallelism_base(mesh, size, compile=False)
+
+
+def _test_fp8_mlp_tensor_parallelism_compile(mesh: DeviceMesh, size=16):
+    _test_fp8_mlp_tensor_parallelism_base(mesh, size, compile=True)
+
+
+if __name__ == "__main__":
+    # float8 only works on CUDA H100 so we only test cuda and we follow
+    # other test files to not use TestCase but instead just add the test
+    # cases in the main func.
+    device_mesh = setup_distributed()
+
+    tests = [
+        _test_fp8_mlp_tensor_parallelism_eager,
+        _test_fp8_mlp_tensor_parallelism_compile,
+    ]
+
+    for test in tqdm(tests, desc="Running tests"):
+        try:
+            test(device_mesh)
+        except Exception as e:
+            print(f"Test {test.__name__} failed with error: {e}")
+            raise e
+
+    torch.distributed.destroy_process_group()
diff --git a/torchao/float8/distributed_utils.py b/torchao/float8/distributed_utils.py
index 4c0b36c35d..cd1560fabd 100644
--- a/torchao/float8/distributed_utils.py
+++ b/torchao/float8/distributed_utils.py
@@ -3,110 +3,25 @@
 #
 # This source code is licensed under the BSD 3-Clause license found in the
 # LICENSE file in the root directory of this source tree.
-from typing import Any
 
 import torch
-from fairscale.nn.model_parallel.initialize import get_model_parallel_group
+import torch.distributed._functional_collectives as funcol
+from torch.distributed._tensor import DTensor
 
-# from float8_tensor import Float8Tensor
 from torchao.float8.float8_tensor import Float8Tensor
 
 
-# additional differentiable distributed primitives for SP which are not in
-# the Fairscale codebase
-
-def _gather_along_first_dim(input_: torch.Tensor):
-    # same as https://github.com/facebookresearch/fairscale/blob/main/fairscale/nn/model_parallel/mappings.py#L67,
-    # but gather along first dim instead of last dim
-    group = get_model_parallel_group()
-
-    # Bypass the function if we are using only 1 GPU.
-    if torch.distributed.get_world_size(group=group) == 1:
-        return input_
-
-    # Size and dimension.
-    first_dim = 0
-    rank = torch.distributed.get_rank(group=group)
-    world_size = torch.distributed.get_world_size(group=group)
-
-    # If the input is a float8 tensor, we need to do the transformation on the
-    # inner tensor and then return a new wrapper.
-    def _transform(t):
-        # tensors must be contiguous for all_gather to work
-        input_contig = t.contiguous()
-
-        tensor_list = [torch.empty_like(input_contig) for _ in range(world_size)]
-        tensor_list[rank] = input_contig
-        torch.distributed.all_gather(tensor_list, input_contig, group=group)
-
-        # Note: torch.cat already creates a contiguous tensor.
-        output = torch.cat(tensor_list, dim=first_dim).contiguous()
-        return output
-
-    if isinstance(input_, Float8Tensor):
-        new_data = input_._data
-        new_data = new_data.view(torch.int8)
-        new_data = _transform(new_data)
-        new_data = new_data.view(input_._data.dtype)
-        output = Float8Tensor(new_data, input_._scale, input_._orig_dtype)
-    else:
-        output = _transform(input_)
-
-    return output
-
-
-def _reduce_scatter(ctx: Any, input_: torch.Tensor):
-    group = get_model_parallel_group()
-    world_size = torch.distributed.get_world_size(group)
-
-    assert input_.shape[0] % world_size == 0
-    output_shape = (input_.shape[0] // world_size, *input_.shape[1:])
-    output = torch.empty(*output_shape, device=input_.device, dtype=input_.dtype)
-
-    torch.distributed.reduce_scatter_tensor(output, input_, group=group)
-    return output
-
-
-def _split_along_first_dim(input_: torch.Tensor):
-    # this is needed for testing
-
-    # like fairscale.nn.model_parallel.mappings._split, but
-    # along the first dim instead of last dim
-
-    group = get_model_parallel_group()
-    local_rank = torch.distributed.get_rank(group)
-    world_size = torch.distributed.get_world_size(group)
-
-    assert input_.shape[0] % world_size == 0
-    input_list = torch.split(input_, input_.shape[0] // world_size)
-    return input_list[local_rank]
-
-
-class _AllGatherFloat8FwReduceScatterBw(torch.autograd.Function):
-    @staticmethod
-    def forward(ctx, input_):
-        return _gather_along_first_dim(input_)
-
-    @staticmethod
-    def backward(ctx, grad_output):
-        return _reduce_scatter(ctx, grad_output)
-
-
-class _ReduceScatterFwAllGatherFloat8Bw(torch.autograd.Function):
-    @staticmethod
-    def forward(ctx, input_):
-        return _reduce_scatter(ctx, input_)
-
-    @staticmethod
-    def backward(ctx, grad_output):
-        return _gather_along_first_dim(grad_output)
-
-
-class _AllGatherFwSplitBw(torch.autograd.Function):
-    @staticmethod
-    def forward(ctx, input_):
-        return _gather_along_first_dim(input_)
-
-    @staticmethod
-    def backward(ctx, grad_output):
-        return _split_along_first_dim(grad_output)
+def tensor_already_casted_to_fp8(tensor: torch.Tensor) -> bool:
+    """
+    Check if the tensor is already cast to fp8; this also works if the local
+    tensor is wrapped in DTensor.
+    """
+    if isinstance(tensor, Float8Tensor):
+        return True
+    elif isinstance(tensor, DTensor):
+        # TODO: shall we stick to public API and directly use tensor.to_local() here?
+        return tensor_already_casted_to_fp8(tensor._local_tensor)
+    elif isinstance(tensor, funcol.AsyncCollectiveTensor):
+        return tensor_already_casted_to_fp8(tensor.elem)
+
+    return False
diff --git a/torchao/float8/float8_linear.py b/torchao/float8/float8_linear.py
index f0c8cc2b55..4e76609f56 100644
--- a/torchao/float8/float8_linear.py
+++ b/torchao/float8/float8_linear.py
@@ -13,6 +13,7 @@
 import torch.utils.checkpoint as checkpoint
 
 from torchao.float8.config import Float8LinearConfig, ScalingGranularity, ScalingType
+from torchao.float8.distributed_utils import tensor_already_casted_to_fp8
 from torchao.float8.float8_scaling_utils import (
     NoopFwToFloat8E5M2BwDelayed,
     NoopFwToFloat8E5M2BwDynamic,
@@ -469,7 +470,7 @@ def cast_input_to_float8(
         return input_fp8
 
     def get_weight_scale(self, weight: torch.Tensor) -> Optional[torch.Tensor]:
-        if isinstance(weight, Float8Tensor):
+        if tensor_already_casted_to_fp8(weight):
             return None
         if self.scaling_type_weight is ScalingType.DELAYED:
             scale_fn_name = self.config.delayed_scaling_config.scale_fn_name
@@ -497,7 +498,7 @@ def cast_weight_to_float8_t(
         is_amax_initialized: bool,
         weight_scale: Optional[torch.Tensor] = None,
     ) -> torch.Tensor:
-        if isinstance(weight, Float8Tensor):
+        if tensor_already_casted_to_fp8(weight):
             return weight.t()
         weight_fp8 = hp_tensor_and_scale_to_float8(
             weight,
diff --git a/torchao/float8/float8_scaling_utils.py b/torchao/float8/float8_scaling_utils.py
index fa5eff733f..6c6f9ca7ca 100644
--- a/torchao/float8/float8_scaling_utils.py
+++ b/torchao/float8/float8_scaling_utils.py
@@ -13,12 +13,12 @@
 import torch
 
 from torchao.float8.config import ScalingGranularity
+from torchao.float8.distributed_utils import tensor_already_casted_to_fp8
 from torchao.float8.float8_tensor import (
     Float8Tensor,
     GemmInputRole,
     LinearMMConfig,
     hp_tensor_and_scale_to_float8,
-    tensor_already_casted_to_fp8,
 )
 from torchao.float8.float8_utils import (
     amax_history_to_scale,
diff --git a/torchao/float8/float8_tensor.py b/torchao/float8/float8_tensor.py
index 21de057fd5..20f40330a8 100644
--- a/torchao/float8/float8_tensor.py
+++ b/torchao/float8/float8_tensor.py
@@ -7,7 +7,6 @@
 from typing import Dict, NamedTuple, Optional
 
 import torch
-import torch.distributed._functional_collectives as funcol
 from torch.distributed._tensor import DTensor
 
 from torchao.float8.float8_utils import (
@@ -121,21 +120,6 @@ def choose_scaled_mm_config(
         raise AssertionError(f"unexpected a_role {a_role} and b_role {b_role}")
 
 
-def tensor_already_casted_to_fp8(tensor: torch.Tensor) -> bool:
-    """
-    Check if the tensor is already casted to fp8
-    """
-    if isinstance(tensor, Float8Tensor):
-        return True
-    elif isinstance(tensor, DTensor):
-        # TODO: shall we stick to public API and directly use tensor.to_local() here?
-        return tensor_already_casted_to_fp8(tensor._local_tensor)
-    elif isinstance(tensor, funcol.AsyncCollectiveTensor):
-        return tensor_already_casted_to_fp8(tensor.elem)
-
-    return False
-
-
 @torch._dynamo.allow_in_graph
 class _ToFloat8ConstrFunc(torch.autograd.Function):
     """
diff --git a/torchao/testing/float8/dtensor_utils.py b/torchao/testing/float8/dtensor_utils.py
new file mode 100644
index 0000000000..1fab31d850
--- /dev/null
+++ b/torchao/testing/float8/dtensor_utils.py
@@ -0,0 +1,31 @@
+# Copyright (c) Meta Platforms, Inc. and affiliates.
+# All rights reserved.
+#
+# This source code is licensed under the BSD 3-Clause license found in the
+# LICENSE file in the root directory of this source tree.
+
+import torch
+import torch.nn as nn
+import torch.nn.functional as F
+
+
+class FeedForward(nn.Module):
+    """MLP based model"""
+
+    def __init__(self):
+        super(FeedForward, self).__init__()
+        self.w1 = nn.Linear(16, 32, bias=False)
+        self.w2 = nn.Linear(16, 32, bias=False)
+        self.out_proj = nn.Linear(32, 16, bias=False)
+
+    def forward(self, x):
+        return self.out_proj(F.silu(self.w1(x)) * self.w2(x))
+
+
+class ToyModel(nn.Module):
+    def __init__(self):
+        super(ToyModel, self).__init__()
+        self.ffn = FeedForward()
+
+    def forward(self, x):
+        return self.ffn(x)