
Conversation

@vkuzo (Contributor) commented Feb 19, 2025

Summary:

Continuing the work from #146427

Adds the `torch.float8_e8m0fnu` dtype to PyTorch, as detailed in
#146414. Please see the issue for a detailed definition of the format. Example of basic functionality:

```python
import torch

# round trip
x0 = torch.randn(4, 4, dtype=torch.float32)
x1 = x0.to(torch.float8_e8m0fnu)  # RNE rounding
x2 = x1.to(torch.float32)  # 2 ** exponent

# creation with empty
x0 = torch.empty(4, 4, dtype=torch.float8_e8m0fnu)

# printing
print(x0)
```
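
For intuition about what the stored bits mean (per the format definition in #146414: the 8 bits are a biased exponent with bias 127, the encoded value is 2 ** (bits - 127), 0xFF is NaN, and there is no sign, mantissa, or zero), here is a minimal decoding sketch. It assumes the uint8 bitwise view works for this dtype as it does for the other float8 dtypes:

```python
import torch

x = torch.tensor([0.25, 1.0, 4.0], dtype=torch.float32).to(torch.float8_e8m0fnu)

# reinterpret the 1-byte storage as uint8 to see the raw biased exponents
bits = x.view(torch.uint8)             # expected: tensor([125, 127, 129], dtype=torch.uint8)

# decode by hand: value = 2 ** (bits - 127)
decoded = 2.0 ** (bits.float() - 127)  # expected: tensor([0.25, 1.00, 4.00])
```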

Done in this PR:

  • numerical correctness
  • op coverage (except for `torch._scaled_mm`): create tensor, cast to/from float32
  • printing a tensor works

For future PRs:

  • performance optimizations for casting
  • `torch._scaled_mm`
  • PT2
  • various cleanups (detailed in comments with issue numbers)

Test Plan:

```
pytest test/quantization/core/experimental/test_float8.py -s
```


cc @jgong5 @mingfeima @XiaobingSuper @sanchitintel @ashokei @jingxu10

pytorch-bot bot commented Feb 19, 2025

🧪 See artifacts and rendered test results at hud.pytorch.org/pr/147466

✅ No Failures

As of commit d635775 with merge base 303ad19:
💚 Looks good so far! There are no failures yet. 💚


@pytorch-bot pytorch-bot bot added the `module: cpu` (CPU specific problem, e.g., perf, algorithm) and `release notes: quantization` (release notes category) labels Feb 19, 2025
@vkuzo vkuzo changed the title from "add the torch.float8_e8m0fnu` dtype to PyTorch" to "add the torch.float8_e8m0fnu dtype to PyTorch" Feb 19, 2025
@vkuzo (Contributor, Author) commented Feb 19, 2025

@pytorchbot merge

@pytorch-bot pytorch-bot bot added the `ciflow/trunk` (Trigger trunk jobs on your pull request) label Feb 19, 2025
@pytorchmergebot (Collaborator)

Merge started

Your change will be merged once all checks pass (ETA 0-4 Hours).


@pytorchmergebot (Collaborator)

Merge failed

Reason: 1 job failed: linux-binary-manywheel / manywheel-py3_9-cuda11_8-build / build


@vkuzo (Contributor, Author) commented Feb 20, 2025

@pytorchbot merge

@pytorchmergebot (Collaborator)

Merge started

Your change will be merged once all checks pass (ETA 0-4 Hours).


@henrylhtsang (Contributor) commented Feb 20, 2025

False alarm, my bad.

mengfei25 added a commit to mengfei25/pytorch that referenced this pull request Mar 6, 2025
jianyizh added a commit to jianyizh/pytorch that referenced this pull request Mar 6, 2025
@github-actions github-actions bot deleted the 20250219_e8m0_intermediate branch March 27, 2025 02:11
pruthvistony pushed a commit to ROCm/pytorch that referenced this pull request Apr 3, 2025
Pull Request resolved: pytorch#147466
Approved by: https://github.com/drisspg
@yiakwy-xpu-ml-framework-team commented Apr 23, 2025

Hi @vkuzo, are you still working on this? I think the way MXFP8_E8M0_FNU is used here is worth discussing.

The number is used inside a group_quantize function: each warp extracts 32 fp32 exponents from 32 fp32 values, so ocp_fp8e8m0fnu_from_fp32 simply amounts to extracting the fp32 exponent.

No mantissa or special numbers need careful handling, because the exponent itself is unsigned, so we don't need to worry about them.

Fp32 -> fp8 + mxfp8_e8m0_fnu

The second point is that this mxfp8_e8m0_fnu scale is shared by a group of 32 consecutive elements, each storing no exponent and only a mantissa (with or without the implicit 1), because multiplying the mantissa by the exponent reconstructs the fp32 value.

Note that multiplying an fp8 value by an fp8 scale is as easy as fp32 = (fp8_scale << fp32_t::M) | (fp8 & fp32_t::INT32_M_MASK);
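
To make that bit trick concrete, here is a small illustration in Python (a hypothetical helper, not PyTorch API; it assumes each group element stores only the 23 fp32 mantissa bits and the shared e8m0 scale supplies the exponent):

```python
import struct

def apply_e8m0_scale(mantissa_bits: int, scale_bits: int) -> float:
    # drop the 8-bit e8m0 scale into the fp32 exponent field and OR in the
    # stored mantissa bits; NaN (0xFF) and the sign bit are ignored here
    fp32_bits = (scale_bits << 23) | (mantissa_bits & 0x007FFFFF)
    (val,) = struct.unpack("<f", struct.pack("<I", fp32_bits))
    return val

# scale = 2**1 (biased exponent 128), mantissa bits for significand 1.5 -> 3.0
assert apply_e8m0_scale(0x400000, 128) == 3.0
```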

Here is my implementation:

```cpp
template<>
HOST_DEVICE_INLINE Float8_E8M0_FNU ocp_fp8e8m0fnu_from_fp32(float fval) {
    using fp32_t = Float;
    using fp8_t = Float8_E8M0_FNU;
    using fp8_storage_t = Float8_E8M0_FNU::Datum;

    // reinterpret the fp32 bits as an integer
    union {
        float fval;
        int32_t i32val;
        uint32_t ui32val;
    } val;

    val.fval = fval;

    // mask off the 8 exponent bits and shift them down past the 23 mantissa bits
    fp8_storage_t ui8val = (val.i32val & fp32_t::INT32_E_MASK) >> fp32_t::M;
    return fp8_t::from_bits(ui8val);
}
```

@vkuzo (Contributor, Author) commented May 23, 2025

hi @yiakwy-xpu-ml-framework-team, sorry for the late reply, I was on a long leave and am now catching up on what I missed.

> The number is used inside a group_quantize function: each warp extracts 32 fp32 exponents from 32 fp32 values, so ocp_fp8e8m0fnu_from_fp32 simply amounts to extracting the fp32 exponent.

Correct! Please check out #146414, "E8M0 detailed proposal" for details. The default cast to e8m0 in PyTorch uses RNE to match the IEEE-754 spec and the other floating point dtypes, which does not match what is described in the OCP spec. It's up to the user to specify a different casting/rounding behavior (such as floor) if they would like to do so - this is 100% valid.
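
To illustrate the difference between the two rounding behaviors, a minimal sketch (the expected RNE result assumes the default cast semantics described above; the floor path is plain fp32 bit manipulation):

```python
import torch

x = torch.tensor([7.0], dtype=torch.float32)

# default PyTorch cast: RNE, so 7.0 rounds to the nearest power of two
rne = x.to(torch.float8_e8m0fnu).to(torch.float32)  # expected: tensor([8.])

# floor/truncation as described in the OCP spec: keep the fp32 exponent bits
bits = x.view(torch.int32)
biased_exp = (bits >> 23) & 0xFF                    # 129 for 7.0
floored = 2.0 ** (biased_exp - 127).float()         # expected: tensor([4.])
```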
