-
Notifications
You must be signed in to change notification settings - Fork 25.7k
ROCm MX-FP8 Gemm #147553
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
ROCm MX-FP8 Gemm #147553
Conversation
🔗 Helpful Links🧪 See artifacts and rendered test results at hud.pytorch.org/pr/147553
Note: Links to docs will display an error until the docs builds have been completed. ❌ 1 New FailureAs of commit 1b40716 with merge base 19a33b2 ( NEW FAILURE - The following job has failed:
This comment was automatically generated by Dr. CI and updates every 15 minutes. |
This PR needs a
|
|
refer to #147548 |
…elated functions. Update error messages to reflect ROCm 6.5 compatibility. Add HIP data type mapping for Float4_e2m1fn_x2. Ensure proper version checks for ROCm in CUDA operations.
…hing - Update IsGfx950Device() to use thread-safe device property caching - Add ValidateMXFormatRequirements() to check matrix dimension constraints - Implement comprehensive validation for MX format in scaled_mm - Improve error messages for MX format requirements - Refactor scale pointer and block size attribute setting for gfx950
- Move IsGfx950Device() and ValidateMXFormatRequirements() to a new GemmMxUtils.h - Namespace the utility functions under at::cuda::tunable - Add include for the new header in Blas.cpp and GemmHipblaslt.h
- Update scaled_gemm signature to include scale dtype parameters for both matrices - Add support for MX format specific block size configuration for gfx950 devices - Modify Blas.cpp to pass scale tensor's scalar type to scaled_gemm - Enhance compatibility with ROCm 6.5+ and gfx950 matrix multiplication
- Extend the list of supported ROCm architectures to include gfx950 - Update the architecture check in Blas.cpp for ROCm version 6.5 and above - Prepare for compatibility with newer ROCm versions and gfx950 devices
Extend CUDA and ROCm support for scaled matrix multiplication with float8_e8m0fnu scales: - Add ROCm version check for scaled GEMM support - Enable scaled GEMM for ROCm 6.5+ with gfx950 architecture - Remove CUDA version-specific error check for float8_e8m0fnu scales
Update block size attribute setting for scaled matrix multiplication in CUDABlas.cpp: - Change `matmul` to `computeDesc` when setting block size attributes - Add TODO comment about potentially needing explicit block size settings for hipblaslt
Extend platform support detection for matrix multiplication with mixed precision: - Add platform check for ROCm with gfx950 architecture - Implement conditional support for MX GEMM on ROCm platforms
Enhance platform support detection for scaled matrix multiplication: - Add a TORCH_CHECK to enforce version requirements for float8_e8m0fnu scales - Ensure compatibility with CUDA 12.8+ and ROCm 6.5+ with gfx950 architecture
Clarify workspace allocation for ROCm platforms in scaled matrix multiplication: - Add a TODO comment to highlight potential need for workspace check - Maintain existing workspace allocation logic for ROCm environments
…d of caching with mutex. This simplifies the code and improves performance for device property checks in GemmMxUtils.h.
e6506c8 to
b155013
Compare
…d clarity and performance. Update CUDABlas.cpp and related files to utilize the new check, ensuring consistent validation for MX format requirements across the codebase.
Ported the patch from pytorch#147553 Commented few lines to avoid compilation error. (check for todo comments) Signed-off-by: Jagadish Krishnamoorthy <jagadish.krishnamoorthy@amd.com>
This PR enables mx data type support on ROCm. Current test stats (accounting only blockwise scale tests) PYTORCH_TEST_WITH_ROCM=1 python test/test_matmul_cuda.py -k test_blockwise -v Ran 452 tests in 17.470s FAILED (failures=2, errors=2, skipped=337) _111_ test pass **fp8 mx data type sample test case.** PYTORCH_TEST_WITH_ROCM=1 python test/test_matmul_cuda.py TestFP8MatmulCudaCUDA.test_blockwise_mxfp8_nvfp4_numerics_test_case_name_a_eye_b_eye_fast_accum_False_128_128_128_recipe_mxfp8_cuda -v **HipblasLT log** hipblaslt-bench --api_method c -m 128 -n 128 -k 128 --lda 128 --ldb 128 --ldc 128 --ldd 128 --stride_a 0 --stride_b 0 --stride_c 0 --stride_d 0 --alpha 1 --beta 0 --transA T --transB N --batch_count 1 **--scaleA 3 --scaleB 3 --a_type f8_r --b_type f8_r -**-c_type bf16_r --d_type bf16_r --compute_type f32_r --algo_method index --solution_index -2147220478 --rotating 0 --cold_iters 0 --iters 0 **fp4 mx data type sample test case.** PYTORCH_TEST_WITH_ROCM=1 python test/test_matmul_cuda.py TestFP8MatmulCudaCUDA.test_blockwise_mxfp8_nvfp4_numerics_test_case_name_a_eye_b_eye_fast_accum_False_128_128_128_recipe_nvfp4_cuda -v **HipblasLT log** hipblaslt-bench --api_method c -m 128 -n 128 -k 128 --lda 128 --ldb 128 --ldc 128 --ldd 128 --stride_a 0 --stride_b 0 --stride_c 0 --stride_d 0 --alpha 1 --beta 0 --transA T --transB N --batch_count 1 **--scaleA 3 --scaleB** 3 **--a_type f4_r --b_type f4_r** --c_type bf16_r --d_type bf16_r --compute_type f32_r --algo_method index --solution_index -2147220478 --rotating 0 --cold_iters 0 --iters 0 Commits: 1. **ROCm MX-FP8 Gemm** (PR from @petrex ) Ported the patch from pytorch#147553 Commented few lines to avoid compilation error. (check for todo comments) 2. **Refine _platform_supports_mx_gemm check** 3. **For mx fp8, A and B need not be kFloat8_e8m0fnu type** 4. **Add fp4 support** (PR from @petrex ) Ported the patch from pytorch#151360 Added fp4 type in aten/src/ATen/cuda/CUDADataType.h Added more mappings in aten/src/ATen/cuda/CUDADataType.h Use e8m0 scaling dtype for fp4 test case for ROCm in test/test_matmul_cuda.py 5. **test_matmul: change code to correctly skip** 6. **test_matmul: skip if nv format** skip tests if Matrix dimensions must be multiples of 32. skip convert to swizzled format 7. **add fp4 support for data_to_mx_scale** 8. **test_matmul: Add mxfp4 test case** --------- Signed-off-by: Jagadish Krishnamoorthy <jagadish.krishnamoorthy@amd.com>
|
Close in favor of #151360 |
TLDR: MX-FP8 matrix multiplications through hipblaslt (require AMD gfx950 && ROCm 6.5+)
This pull request introduces several changes to enhance support for the MX format on ROCm, particularly for the gfx950 device. Key changes include adding validation for matrix dimensions and setting block sizes for the MX format, as well as updating the scaling logic to accommodate new requirements.
Enhancements for MX format on ROCm:
aten/src/ATen/cuda/CUDABlas.cpp: Added validation for matrix dimensions and set block sizes for MX format when using ROCm version 6.5 or later on gfx950 devices. [1] [2]aten/src/ATen/cuda/tunable/GemmHipblaslt.h: Included validation and block size settings for MX format in theHipblasltGemmOpclass.aten/src/ATen/native/cuda/Blas.cpp: Added validation for MX format requirements and updated scaling logic for block-wise scaling on gfx950 devices. [1] [2]Refactoring and utility functions:
aten/src/ATen/cuda/tunable/GemmMxUtils.h: Introduced helper functionsIsGfx950DeviceandValidateMXFormatRequirementsto cache device properties and validate MX format requirements.aten/src/ATen/native/cuda/Blas.cpp: Added a helper functionIsGfx950Deviceto cache device properties and updated the_scaled_mm_out_cudafunction to include MX format validation. [1] [2]Other changes:
torch/testing/_internal/common_cuda.py: Updated platform support check for MX GEMM to include gfx950 devices on ROCm.torch/utils/hipify/cuda_to_hip_mappings.py: Added mappings for new MX format attributes and scaling modes.cc @jeffdaily @sunway513 @jithunnair-amd @pruthvistony @ROCmSupport @dllehr-amd @jataylo @hongxiayang @naromero77amd @gujinghui @PenghuiCheng @XiaobingSuper @jianyuh @jgong5 @mingfeima @sanchitintel @ashokei @jingxu10 @min-jean-cho @yanbing-j @Guobing-Chen @Xia-Weiwen @snadampal