Closed
Hi, we are hitting the error "attempting to run aten.abs.default, this is not supported" when running torchtitan with 2D parallelism (FSDP + TP) and FP8 all-gather (enable_fsdp_float8_all_gather=True).
We did not see this issue when we did the same runs back in August, so we ran a few more experiments and confirmed that later torchao versions introduced it (at least on our end): reverting to a previous git commit (from Aug/Sep) makes the issue go away.
To provide a little more info:
- fp8 linear alone does not trigger this issue (i.e. the error only appears when fp8 all-gather is also enabled)
- 1D does not trigger this issue (i.e. FSDP only, with tp=1, does not trigger the error)
- older versions of torchao do not have this issue
To reproduce:
Run the latest torchtitan + latest torchao as is, with tp>1 and all three fp8 flags (enable_float8_linear, enable_fsdp_float8_all_gather, precompute_float8_dynamic_scale_for_fsdp) set to True.
Error trace:
[rank0]: attempting to run aten.abs.default, this is not supported
[rank0]: from user code:
[rank0]: File "/proj/data-eng/lchu/miniconda3/envs/latest/lib/python3.10/site-packages/torch/distributed/algorithms/_checkpoint/checkpoint_wrapper.py", line 170, in forward
[rank0]: return self.checkpoint_fn( # type: ignore[misc]
[rank0]: File "/proj/data-eng/lchu/torchtitan_latest/torchtitan/models/llama/model.py", line 324, in forward
[rank0]: h = x + self.attention(self.attention_norm(x), freqs_cis)
[rank0]: File "/proj/data-eng/lchu/miniconda3/envs/latest/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1841, in _call_impl
[rank0]: return inner()
[rank0]: File "/proj/data-eng/lchu/miniconda3/envs/latest/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1790, in inner
[rank0]: result = forward_call(*args, **kwargs)
[rank0]: File "/proj/data-eng/lchu/torchtitan_latest/torchtitan/models/llama/model.py", line 190, in forward
[rank0]: xq, xk, xv = self.wq(x), self.wk(x), self.wv(x)
[rank0]: File "/proj/data-eng/lchu/miniconda3/envs/latest/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1841, in _call_impl
[rank0]: return inner()
[rank0]: File "/proj/data-eng/lchu/miniconda3/envs/latest/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1790, in inner
[rank0]: result = forward_call(*args, **kwargs)
[rank0]: File "/proj/data-eng/lchu/miniconda3/envs/latest/lib/python3.10/site-packages/torchao/float8/float8_linear.py", line 570, in forward
[rank0]: weight_scale = self.get_weight_scale(self.weight)
[rank0]: File "/proj/data-eng/lchu/miniconda3/envs/latest/lib/python3.10/site-packages/torchao/float8/float8_linear.py", line 487, in get_weight_scale
[rank0]: return tensor_to_scale(weight, e4m3_dtype)
[rank0]: File "/proj/data-eng/lchu/miniconda3/envs/latest/lib/python3.10/site-packages/torch/utils/_contextlib.py", line 116, in decorate_context
[rank0]: return func(*args, **kwargs)
[rank0]: File "/proj/data-eng/lchu/miniconda3/envs/latest/lib/python3.10/site-packages/torchao/float8/float8_utils.py", line 138, in tensor_to_scale
[rank0]: amax = tensor_to_amax(
[rank0]: File "/proj/data-eng/lchu/miniconda3/envs/latest/lib/python3.10/site-packages/torch/utils/_contextlib.py", line 116, in decorate_context
[rank0]: return func(*args, **kwargs)
[rank0]: File "/proj/data-eng/lchu/miniconda3/envs/latest/lib/python3.10/site-packages/torchao/float8/float8_utils.py", line 113, in tensor_to_amax
[rank0]: amax = torch.max(torch.abs(x))
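For context, the failing frame is torchao's tensor_to_amax, which computes amax = torch.max(torch.abs(x)) to derive the dynamic fp8 scale. Below is a minimal pure-Python sketch of that scaling math; it is a hypothetical simplification, not the torchao implementation (the real code operates on torch tensors, and 448.0 is the finfo max of torch.float8_e4m3fn):

```python
# Hypothetical, simplified sketch of torchao's dynamic fp8 scaling math.
E4M3_MAX = 448.0  # finfo max of torch.float8_e4m3fn

def tensor_to_amax(x):
    # Mirrors the failing line: amax = torch.max(torch.abs(x)).
    return max(abs(v) for v in x)

def tensor_to_scale(x, eps=1e-12):
    # Scale maps the tensor's absolute max onto the fp8 e4m3 range.
    amax = tensor_to_amax(x)
    return E4M3_MAX / max(amax, eps)

print(tensor_to_scale([-2.0, 0.5, 1.0]))  # 448.0 / 2.0 = 224.0
```

Under 2D parallelism with fp8 all-gather enabled, the weight passed into this code path is a wrapped/distributed tensor, and the torch.abs call is presumably the op that has no dispatch rule registered for that tensor subclass, producing the error above.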