attempting to run aten.abs.default, this is not supported with latest torchtitan + torchao #1313

Description

@lchu6

Hi, we are hitting an `attempting to run aten.abs.default, this is not supported` error when running torchtitan with 2D parallelism (FSDP + TP) and FP8 all-gather (`enable_fsdp_float8_all_gather=True`).
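
For context, a minimal sketch of the 2D layout in question, assuming a torchrun-launched process group of 4 ranks; mesh sizes and modules are illustrative, not torchtitan's actual code:

```python
import torch.nn as nn
from torch.distributed.device_mesh import init_device_mesh
from torch.distributed.tensor.parallel import ColwiseParallel, parallelize_module
from torch.distributed._composable.fsdp import fully_shard  # FSDP2

# 2 data-parallel x 2 tensor-parallel ranks.
mesh = init_device_mesh("cuda", (2, 2), mesh_dim_names=("dp", "tp"))

model = nn.Sequential(nn.Linear(1024, 1024), nn.Linear(1024, 1024))

# TP first: shard each linear across the "tp" mesh dimension.
for layer in model:
    parallelize_module(layer, mesh["tp"], ColwiseParallel())

# Then FSDP2: shard the parameters across the "dp" mesh dimension.
fully_shard(model, mesh=mesh["dp"])
```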

We did not see this issue when we ran the same jobs back in August, so we ran a few more experiments and confirmed that a later torchao version introduced it (at least on our end): reverting torchao to an earlier commit (Aug/Sep) makes the issue go away.

To provide a little more info:

  1. FP8 linear alone does not trigger the issue (the error appears only when FP8 all-gather is also enabled).
  2. 1D does not trigger the issue (FSDP only, with tp=1, runs fine).
  3. Older versions of torchao do not have the issue.

To reproduce:
Run the latest torchtitan + latest torchao as-is, with tp > 1 and all three FP8 flags (`enable_float8_linear`, `enable_fsdp_float8_all_gather`, `precompute_float8_dynamic_scale_for_fsdp`) set to `True`.
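
These flags correspond to torchao's float8 training API. A minimal single-process sketch of that wiring, assuming the documented `torchao.float8` entry points (the actual torchtitan code and the TP/FSDP wrapping above are omitted here):

```python
import torch.nn as nn
from torchao.float8 import (
    Float8LinearConfig,
    convert_to_float8_training,
    precompute_float8_dynamic_scale_for_fsdp,
)

# Toy stand-in for the linears torchtitan converts in the Llama blocks.
model = nn.Sequential(nn.Linear(1024, 1024), nn.Linear(1024, 1024))

# enable_float8_linear + enable_fsdp_float8_all_gather: swap nn.Linear for
# Float8Linear, with weights wrapped so FSDP2 all-gathers them in FP8.
config = Float8LinearConfig(enable_fsdp_float8_all_gather=True)
convert_to_float8_training(model, config=config)

# In the failing run, TP (tp > 1) and FSDP2 are applied to `model` here.

# precompute_float8_dynamic_scale_for_fsdp is called once per training step,
# after optimizer.step(), to batch the amax/scale computation for all FP8
# weights (a no-op in this single-process sketch).
precompute_float8_dynamic_scale_for_fsdp(model)
```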

Error trace:

[rank0]: attempting to run aten.abs.default, this is not supported


[rank0]: from user code:
[rank0]:    File "/proj/data-eng/lchu/miniconda3/envs/latest/lib/python3.10/site-packages/torch/distributed/algorithms/_checkpoint/checkpoint_wrapper.py", line 170, in forward
[rank0]:     return self.checkpoint_fn(  # type: ignore[misc]
[rank0]:   File "/proj/data-eng/lchu/torchtitan_latest/torchtitan/models/llama/model.py", line 324, in forward
[rank0]:     h = x + self.attention(self.attention_norm(x), freqs_cis)
[rank0]:   File "/proj/data-eng/lchu/miniconda3/envs/latest/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1841, in _call_impl
[rank0]:     return inner()
[rank0]:   File "/proj/data-eng/lchu/miniconda3/envs/latest/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1790, in inner
[rank0]:     result = forward_call(*args, **kwargs)
[rank0]:   File "/proj/data-eng/lchu/torchtitan_latest/torchtitan/models/llama/model.py", line 190, in forward
[rank0]:     xq, xk, xv = self.wq(x), self.wk(x), self.wv(x)
[rank0]:   File "/proj/data-eng/lchu/miniconda3/envs/latest/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1841, in _call_impl
[rank0]:     return inner()
[rank0]:   File "/proj/data-eng/lchu/miniconda3/envs/latest/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1790, in inner
[rank0]:     result = forward_call(*args, **kwargs)
[rank0]:   File "/proj/data-eng/lchu/miniconda3/envs/latest/lib/python3.10/site-packages/torchao/float8/float8_linear.py", line 570, in forward
[rank0]:     weight_scale = self.get_weight_scale(self.weight)
[rank0]:   File "/proj/data-eng/lchu/miniconda3/envs/latest/lib/python3.10/site-packages/torchao/float8/float8_linear.py", line 487, in get_weight_scale
[rank0]:     return tensor_to_scale(weight, e4m3_dtype)
[rank0]:   File "/proj/data-eng/lchu/miniconda3/envs/latest/lib/python3.10/site-packages/torch/utils/_contextlib.py", line 116, in decorate_context
[rank0]:     return func(*args, **kwargs)
[rank0]:   File "/proj/data-eng/lchu/miniconda3/envs/latest/lib/python3.10/site-packages/torchao/float8/float8_utils.py", line 138, in tensor_to_scale
[rank0]:     amax = tensor_to_amax(
[rank0]:   File "/proj/data-eng/lchu/miniconda3/envs/latest/lib/python3.10/site-packages/torch/utils/_contextlib.py", line 116, in decorate_context
[rank0]:     return func(*args, **kwargs)
[rank0]:   File "/proj/data-eng/lchu/miniconda3/envs/latest/lib/python3.10/site-packages/torchao/float8/float8_utils.py", line 113, in tensor_to_amax
[rank0]:     amax = torch.max(torch.abs(x))
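
For context, the bottom of the trace is torchao's dynamic scaling path: `tensor_to_scale` derives a scale from the tensor's amax, which it computes via `torch.max(torch.abs(x))`. A simplified sketch of that computation (names follow the trace; the clamp epsilon is an assumption based on common torchao defaults):

```python
import torch

E4M3_MAX = torch.finfo(torch.float8_e4m3fn).max  # 448.0

def tensor_to_scale_sketch(x: torch.Tensor) -> torch.Tensor:
    # Simplified from the torchao.float8.float8_utils frames in the trace:
    # the dynamic scale is derived from the max absolute value, so it has to
    # run aten.abs.default (and aten.max) on the weight tensor.
    amax = torch.max(torch.abs(x))  # the op that fails above
    return E4M3_MAX / torch.clamp(amax.float(), min=1e-12)

# With enable_fsdp_float8_all_gather=True under 2D parallelism, the weight
# reaching this code is a tensor subclass (wrapped for FP8 all-gather, inside
# a DTensor), and its __torch_dispatch__ rejects ops it does not implement,
# which appears to be where "this is not supported" comes from.
```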
