ROCm MX-FP8 Gemm by petrex · Pull Request #147553 · pytorch/pytorch

Conversation

@petrex
Contributor

@petrex petrex commented Feb 20, 2025

TL;DR: MX-FP8 matrix multiplications through hipblaslt (requires AMD gfx950 and ROCm 6.5+).

This pull request enhances support for the MX format on ROCm, specifically for gfx950 devices. Key changes include validating matrix dimensions, setting block sizes for the MX format, and updating the scaling logic to meet the new requirements. A minimal usage sketch follows the change summary below.

Enhancements for MX format on ROCm:

Refactoring and utility functions:

Other changes:
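For context, here is a minimal usage sketch of what this enables from Python. It assumes a gfx950 device with ROCm 6.5+ and a PyTorch build containing this PR; the tensor shapes are illustrative, and the exact scale layout expected by the ROCm MX path may differ from what is shown.

```python
import torch

# Minimal sketch: MX-FP8 scaled matmul with 1x32-block e8m0 scales.
# Assumes a gfx950 device, ROCm 6.5+, and a PyTorch build with this PR;
# the scale layout (one scale per 32-element block along K, unswizzled)
# is an assumption made for illustration.
M, N, K = 128, 128, 128  # MX format expects dimensions that are multiples of 32

a = torch.randn(M, K, device="cuda").to(torch.float8_e4m3fn)
b = torch.randn(N, K, device="cuda").to(torch.float8_e4m3fn)

# Unit scales for illustration; real scales are derived from per-block maxima.
scale_a = torch.ones(M, K // 32, device="cuda").to(torch.float8_e8m0fnu)
scale_b = torch.ones(N, K // 32, device="cuda").to(torch.float8_e8m0fnu)

out = torch._scaled_mm(a, b.t(), scale_a=scale_a, scale_b=scale_b,
                       out_dtype=torch.bfloat16)
print(out.shape)  # torch.Size([128, 128])
```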

cc @jeffdaily @sunway513 @jithunnair-amd @pruthvistony @ROCmSupport @dllehr-amd @jataylo @hongxiayang @naromero77amd @gujinghui @PenghuiCheng @XiaobingSuper @jianyuh @jgong5 @mingfeima @sanchitintel @ashokei @jingxu10 @min-jean-cho @yanbing-j @Guobing-Chen @Xia-Weiwen @snadampal

@pytorch-bot pytorch-bot bot added the module: rocm AMD GPU support for Pytorch label Feb 20, 2025
@pytorch-bot

pytorch-bot bot commented Feb 20, 2025

🔗 Helpful Links

🧪 See artifacts and rendered test results at hud.pytorch.org/pr/147553

Note: Links to docs will display an error until the docs builds have been completed.

❌ 1 New Failure

As of commit 1b40716 with merge base 19a33b2:

NEW FAILURE - The following job has failed:

This comment was automatically generated by Dr. CI and updates every 15 minutes.

@petrex petrex marked this pull request as draft February 20, 2025 21:43
@github-actions
Contributor

This PR needs a release notes: label

If your changes are user facing and intended to be a part of release notes, please use a label starting with release notes:.

If not, please add the topic: not user facing label.

To add a label, you can comment to pytorchbot, for example
@pytorchbot label "topic: not user facing"

For more information, see
https://github.com/pytorch/pytorch/wiki/PyTorch-AutoLabel-Bot#why-categorize-for-release-notes-and-how-does-it-work.

@petrex
Contributor Author

petrex commented Feb 26, 2025

refer to #147548

@pytorch-bot pytorch-bot bot added the module: mkldnn Related to Intel IDEEP or oneDNN (a.k.a. mkldnn) integration label Mar 5, 2025
Peter Y. Yeh added 19 commits April 15, 2025 14:57
…elated functions. Update error messages to reflect ROCm 6.5 compatibility. Add HIP data type mapping for Float4_e2m1fn_x2. Ensure proper version checks for ROCm in CUDA operations.
…hing

- Update IsGfx950Device() to use thread-safe device property caching
- Add ValidateMXFormatRequirements() to check matrix dimension constraints
- Implement comprehensive validation for MX format in scaled_mm
- Improve error messages for MX format requirements
- Refactor scale pointer and block size attribute setting for gfx950
- Move IsGfx950Device() and ValidateMXFormatRequirements() to a new GemmMxUtils.h
- Namespace the utility functions under at::cuda::tunable
- Add include for the new header in Blas.cpp and GemmHipblaslt.h
- Update scaled_gemm signature to include scale dtype parameters for both matrices
- Add support for MX format specific block size configuration for gfx950 devices
- Modify Blas.cpp to pass scale tensor's scalar type to scaled_gemm
- Enhance compatibility with ROCm 6.5+ and gfx950 matrix multiplication
- Extend the list of supported ROCm architectures to include gfx950
- Update the architecture check in Blas.cpp for ROCm version 6.5 and above
- Prepare for compatibility with newer ROCm versions and gfx950 devices
Extend CUDA and ROCm support for scaled matrix multiplication with float8_e8m0fnu scales:
- Add ROCm version check for scaled GEMM support
- Enable scaled GEMM for ROCm 6.5+ with gfx950 architecture
- Remove CUDA version-specific error check for float8_e8m0fnu scales
Update block size attribute setting for scaled matrix multiplication in CUDABlas.cpp:
- Change `matmul` to `computeDesc` when setting block size attributes
- Add TODO comment about potentially needing explicit block size settings for hipblaslt
Extend platform support detection for matrix multiplication with mixed precision:
- Add platform check for ROCm with gfx950 architecture
- Implement conditional support for MX GEMM on ROCm platforms
Enhance platform support detection for scaled matrix multiplication:
- Add a TORCH_CHECK to enforce version requirements for float8_e8m0fnu scales
- Ensure compatibility with CUDA 12.8+ and ROCm 6.5+ with gfx950 architecture
Clarify workspace allocation for ROCm platforms in scaled matrix multiplication:
- Add a TODO comment to highlight potential need for workspace check
- Maintain existing workspace allocation logic for ROCm environments
…d of caching with mutex. This simplifies the code and improves performance for device property checks in GemmMxUtils.h.
@petrex petrex force-pushed the rocm_mx_fp8_scaled_mm branch from e6506c8 to b155013 Compare April 15, 2025 22:28
Peter Y. Yeh added 5 commits April 15, 2025 15:43
…d clarity and performance. Update CUDABlas.cpp and related files to utilize the new check, ensuring consistent validation for MX format requirements across the codebase.
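The dimension constraint enforced by the new check is simple. Below is a hypothetical Python rendering of the ValidateMXFormatRequirements() logic mentioned in the commits above; the actual utility is a C++ helper in GemmMxUtils.h and may check more than this sketch shows.

```python
MX_BLOCK_SIZE = 32  # MX formats scale contiguous 32-element blocks

def validate_mx_format_requirements(m: int, n: int, k: int) -> None:
    """Hypothetical Python mirror of the C++ ValidateMXFormatRequirements():
    the MX-FP8 GEMM path expects every dimension to be a multiple of the
    32-element scaling block size."""
    for name, dim in (("M", m), ("N", n), ("K", k)):
        if dim % MX_BLOCK_SIZE != 0:
            raise ValueError(
                f"MX format requires {name}={dim} to be a multiple of {MX_BLOCK_SIZE}"
            )

validate_mx_format_requirements(128, 128, 128)   # OK
# validate_mx_format_requirements(100, 128, 128) # would raise ValueError
```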
jagadish-amd added a commit to ROCm/pytorch that referenced this pull request Apr 23, 2025
Ported the patch from pytorch#147553.
Commented out a few lines to avoid compilation errors (check the TODO comments).

Signed-off-by: Jagadish Krishnamoorthy <jagadish.krishnamoorthy@amd.com>
pruthvistony pushed a commit to ROCm/pytorch that referenced this pull request May 29, 2025
This PR enables MX data type support on ROCm.

Current test stats (counting only the blockwise scale tests):
PYTORCH_TEST_WITH_ROCM=1 python test/test_matmul_cuda.py -k test_blockwise -v
Ran 452 tests in 17.470s
FAILED (failures=2, errors=2, skipped=337)
111 tests pass.

**fp8 mx data type sample test case.**
PYTORCH_TEST_WITH_ROCM=1 python test/test_matmul_cuda.py TestFP8MatmulCudaCUDA.test_blockwise_mxfp8_nvfp4_numerics_test_case_name_a_eye_b_eye_fast_accum_False_128_128_128_recipe_mxfp8_cuda -v

**HipblasLT log** hipblaslt-bench --api_method c -m 128 -n 128 -k 128 --lda 128 --ldb 128 --ldc 128 --ldd 128 --stride_a 0 --stride_b 0 --stride_c 0 --stride_d 0 --alpha 1 --beta 0 --transA T --transB N --batch_count 1 **--scaleA 3 --scaleB 3 --a_type f8_r --b_type f8_r** --c_type bf16_r --d_type bf16_r --compute_type f32_r --algo_method index --solution_index -2147220478 --rotating 0 --cold_iters 0 --iters 0

**fp4 mx data type sample test case.**
PYTORCH_TEST_WITH_ROCM=1 python test/test_matmul_cuda.py TestFP8MatmulCudaCUDA.test_blockwise_mxfp8_nvfp4_numerics_test_case_name_a_eye_b_eye_fast_accum_False_128_128_128_recipe_nvfp4_cuda -v

**HipblasLT log** hipblaslt-bench --api_method c -m 128 -n 128 -k 128 --lda 128 --ldb 128 --ldc 128 --ldd 128 --stride_a 0 --stride_b 0 --stride_c 0 --stride_d 0 --alpha 1 --beta 0 --transA T --transB N --batch_count 1 **--scaleA 3 --scaleB 3 --a_type f4_r --b_type f4_r** --c_type bf16_r --d_type bf16_r --compute_type f32_r --algo_method index --solution_index -2147220478 --rotating 0 --cold_iters 0 --iters 0
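The blockwise e8m0 scales that drive these test cases are derived per 32-element block from the input data. Below is a rough Python sketch of the idea, modeled loosely on the data_to_mx_scale helper in test/test_matmul_cuda.py; the rounding and clamping here are simplified, so treat it as an illustration rather than the exact implementation.

```python
import torch

FP8_E4M3_MAX = 448.0  # largest finite value of float8_e4m3fn
BLOCK = 32            # MX scaling block size along the K dimension

def mx_e8m0_scale(x: torch.Tensor) -> torch.Tensor:
    """Compute one power-of-two (e8m0) scale per 1x32 block of x.

    Simplified sketch: pick a power of two so that each block's absolute
    maximum, divided by the scale, lands within the FP8 e4m3 range.
    """
    rows, cols = x.shape
    blocks = x.reshape(rows, cols // BLOCK, BLOCK)
    amax = blocks.abs().amax(dim=-1).clamp_min(1e-12)  # avoid log2(0) on zero blocks
    exp = torch.ceil(torch.log2(amax / FP8_E4M3_MAX))  # round up so amax/scale <= 448
    return torch.exp2(exp).to(torch.float8_e8m0fnu)    # one scale per block

a = torch.randn(128, 128)
print(mx_e8m0_scale(a).shape)  # torch.Size([128, 4])
```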


Commits:

1. **ROCm MX-FP8 Gemm** (PR from @petrex)
Ported the patch from pytorch#147553.
Commented out a few lines to avoid compilation errors (check the TODO comments).

2. **Refine _platform_supports_mx_gemm check**

3. **For mx fp8, A and B need not be kFloat8_e8m0fnu type**

4. **Add fp4 support** (PR from @petrex)
Ported the patch from pytorch#151360
Added fp4 type in aten/src/ATen/cuda/CUDADataType.h
Added more mappings in aten/src/ATen/cuda/CUDADataType.h
Use e8m0 scaling dtype for fp4 test case for ROCm in test/test_matmul_cuda.py

5. **test_matmul: change code to correctly skip**
6. **test_matmul: skip if nv format**
Skip tests whose matrix dimensions are not multiples of 32.
Skip the conversion to swizzled format.
7.  **add fp4 support for data_to_mx_scale**
8. **test_matmul: Add mxfp4 test case**

---------

Signed-off-by: Jagadish Krishnamoorthy <jagadish.krishnamoorthy@amd.com>
@petrex petrex closed this Jun 3, 2025
@petrex
Contributor Author

petrex commented Jun 3, 2025

Closed in favor of #151360.
