
[Feature Request] FP8 Support #156

Open
LeiWang1999 opened this issue Feb 4, 2025 · 12 comments

Comments

@LeiWang1999

It would be great to have FP8 support for converting tensors from PyTorch to DLPack. Currently, both PyTorch and TVM support FP8, but there is no direct way to convert tensors between them. Adding this support would improve interoperability and usability.

@tqchen
Member

tqchen commented Feb 4, 2025

Yes, we should do it. cc @leofang

@LeiWang1999
Author

LeiWang1999 commented Feb 4, 2025

Thanks tq, I found a temporary solution that works for me:

# Bit-cast the fp8 tensor to int8 so it can cross the DLPack boundary, then
# reinterpret the result as the matching fp8 dtype on the consumer side.
if arg.dtype in {torch.float8_e4m3fn, torch.float8_e4m3fnuz, torch.float8_e5m2, torch.float8_e5m2fnuz}:
    return ndarray.from_dlpack(
        to_dlpack_func(arg.view(torch.int8))
    ).view(dtype=float8_dtype_map[arg.dtype])
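
For a self-contained illustration of the same trick, here is a minimal sketch that round-trips an fp8 tensor through DLPack into NumPy instead of TVM, assuming a recent NumPy and the ml_dtypes package are installed (ndarray, to_dlpack_func, and float8_dtype_map in the snippet above are helpers from the surrounding code, not shown here):

    import numpy as np
    import torch
    import ml_dtypes

    # Producer side: an fp8 tensor that DLPack cannot yet describe directly.
    x = torch.randn(4, 4).to(torch.float8_e4m3fn)

    # Bit-cast to int8 (same element size), exchange via DLPack, then
    # reinterpret the raw bytes as fp8 again on the consumer side.
    y = np.from_dlpack(x.view(torch.int8)).view(ml_dtypes.float8_e4m3fn)

    assert y.shape == (4, 4)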

@potatomashed

potatomashed commented Feb 4, 2025

FP8 (and FP4) have become common practice in LLM training, and in MLC-Python we had to extend DLPack to support the various FP8 types: https://github.com/mlc-ai/mlc-python/blob/0a22cf87d1888cf39dcd2f856f866c7e1d41a568/include/mlc/c_api.h#L35-L45.

Regarding DLPack support, the trickiest issue with fp8/fp4 is that there are multiple different sub-types: e.g. float8_e3m4 and float8_e4m3 are both fp8 but have different exponent/mantissa splits. In the DLPack design, they will have to take two different dtype codes.

A common standard people may refer to is the ml-dtypes package from JAX, which is a superset of, and consistent with, PyTorch's fp8 types. This is what MLC-Python adopts as well. I'm happy to upstream full fp8 support from MLC-Python back to DLPack if the implementation looks okay.
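
For illustration, a small sketch (assuming a recent ml_dtypes release) showing that two fp8 subtypes really do differ in exponent/mantissa layout, which is why they cannot share one dtype code:

    import numpy as np
    import ml_dtypes

    # Compare the bit layouts of two fp8 subtypes; each would need its own dtype code.
    for dt in (ml_dtypes.float8_e3m4, ml_dtypes.float8_e4m3):
        info = ml_dtypes.finfo(dt)
        print(np.dtype(dt).name, "exponent bits:", info.nexp,
              "mantissa bits:", info.nmant, "max:", float(info.max))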

@tqchen
Member

tqchen commented Feb 4, 2025

Starting with the PyTorch types seems like a reasonable choice, given the available hardware support; contributions are welcome. We can add additional type codes if future needs arise.

@leofang
Collaborator

leofang commented Feb 4, 2025

cc @seberg @oleksandr-pavlyk @rgommers for visibility; let's try to get this discussed in the array API meeting this week.

@leofang
Collaborator

leofang commented Feb 4, 2025

I'm happy to upstream full fp8 support from MLC-Python back to DLPack if the implementation looks okay.

@potatomashed I am curious what exactly you meant by "full fp8 support" here; I assume you're referring to the set of additional enumerators needed for representing different fp8 subtypes. Are there things beyond this addition?

@potatomashed

@potatomashed I am curious what exactly you meant by "full fp8 support" here; I assume you're referring to the set of additional enumerators needed for representing different fp8 subtypes. Are there things beyond this addition?

Yep, that's just a few extra fp8 subtypes.

For reference, in PyTorch 2.6, the following fp8 subtypes are supported:

  • float8_e4m3fn
  • float8_e4m3fnuz
  • float8_e5m2
  • float8_e5m2fnuz

while ml-dtypes has some additional ones:

  • float8_e3m4
  • float8_e4m3
  • float8_e4m3b11fnuz
  • float8_e8m0fnu
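
As a quick check of the lists above (a sketch, assuming PyTorch >= 2.6 and a recent ml_dtypes release), one can simply probe which names each library exposes:

    import torch
    import ml_dtypes

    # Names taken from the two lists above; availability depends on library versions.
    names = [
        "float8_e4m3fn", "float8_e4m3fnuz", "float8_e5m2", "float8_e5m2fnuz",
        "float8_e3m4", "float8_e4m3", "float8_e4m3b11fnuz", "float8_e8m0fnu",
    ]
    for name in names:
        print(f"{name:<20} torch={hasattr(torch, name)} ml_dtypes={hasattr(ml_dtypes, name)}")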

@tqchen
Member

tqchen commented Feb 4, 2025

Some extra survey and thoughts:

Likely these subtypes from PyTorch are needed:

  • float8_e4m3fn
  • float8_e4m3fnuz
  • float8_e5m2
  • float8_e5m2fnuz

The latest Blackwell microscaling support appears to use float8_e8m0fnu as a scaling factor. Combined with float4_e2m1fn, that would enable microscaling support for FP4. In that case, the microscaling format would consist of two DLPack arrays (one float4_e2m1fn array for the weights and another for the scales).

It would be good to also discuss potential use cases for other data types, but this could be a good initial list (along with float4_e2m1fn).
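
To make the two-array layout concrete, here is a hedged sketch (assuming a recent ml_dtypes release that exposes float4_e2m1fn and float8_e8m0fnu; the group size of 32 and the dequantize_mx helper are illustrative, not part of DLPack):

    import numpy as np
    import ml_dtypes

    GROUP = 32  # illustrative MX group size

    # Array 1: quantized fp4 values; Array 2: one power-of-two scale per group of 32.
    values = np.zeros((4, 64), dtype=ml_dtypes.float4_e2m1fn)
    scales = np.ones((4, 64 // GROUP), dtype=ml_dtypes.float8_e8m0fnu)

    def dequantize_mx(values: np.ndarray, scales: np.ndarray) -> np.ndarray:
        """Reconstruct float32 data: each group of GROUP values shares one e8m0 scale."""
        v = values.astype(np.float32).reshape(values.shape[0], -1, GROUP)
        s = scales.astype(np.float32)[..., None]
        return (v * s).reshape(values.shape)

    print(dequantize_mx(values, scales).shape)  # (4, 64)

Exchanging such data over DLPack would then simply mean passing these two arrays separately.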

@potatomashed

potatomashed commented Feb 5, 2025

MX FP training is indeed a valid use case, e.g. MXFP4 or asymmetric MXFP4 (https://arxiv.org/abs/2411.09909). I don't think we have to go over the top speculating about future applications, but given the diverse set of existing sub-byte dtypes, I'd love to learn the DLPack maintainers' principles/rules on which dtypes to include.

@tqchen
Member

tqchen commented Feb 5, 2025

As of now we focus on reasonably stabilized types, mainly because the goal is to enable frameworks to exchange data and to remain stable over time.

Notably, MX formats are usually stored in two NDArrays, e.g. a float8_e8m0fnu group scale plus float4_e2m1fn values. That means for now we can focus on the individual component types, i.e. float8_e8m0fnu and float4_e2m1fn.

@potatomashed

Notably, MX format usually are stored in two NDArrays

This assumption doesn't always hold, but it's a good starting point.

@leofang
Collaborator

leofang commented Feb 5, 2025

xref: pytorch/pytorch#146414
