AMD/ROCm OCP Micro-scaling Format (mx-fp8/mx-fp4) Support by petrex · Pull Request #151360 · pytorch/pytorch · GitHub

Conversation

@petrex
Contributor

@petrex petrex commented Apr 15, 2025

  • This pull request introduces support for the [OCP Micro-scaling (MX) format](https://www.opencompute.org/documents/ocp-microscaling-formats-mx-v1-0-spec-final-pdf), with a focus on compatibility with AMD ROCm 7.0 and the gfx950 architecture.

    This PR also establishes the foundation for enabling MX-FPX features in [TorchAO](pytorch/ao#2229) on the AMD platform.

  • Validation (ROCm 7.0 + gfx950 required):

    111 relevant tests passing.

    PYTORCH_TEST_WITH_ROCM=1 python test/test_matmul_cuda.py -k test_blockwise -v

    Co-author: @jagadish-amd — Thank you for the efforts leading validation on gfx950 with ROCm 7.0.


This pull request introduces support for new scalar types and scaling methods, particularly for ROCm 7.0 and gfx950, and refines testing for these features. Key changes include adding constraints for matrix dimensions, enabling block-wise scaling, and updating tests to accommodate new data types.

Support for new scalar types and scaling methods:

  • aten/src/ATen/cuda/CUDABlas.cpp: Added constraints for matrix dimensions when using Float8_e8m0fnu with block-wise scaling, ensuring dimensions are multiples of 32. Updated compatibility checks to support ROCm 7.0 for Float8_e8m0fnu and Float8_e4m3fn.

  • aten/src/ATen/native/cuda/Blas.cpp: Introduced block-wise scaling for Float8_e8m0fnu, with checks for ROCm 7.0 and the gfx950 GPU architecture. Added validation for supported scalar types and matrix dimensions.

Updates to scalar type mappings:

  • aten/src/ATen/cuda/CUDADataType.h: Extended scalar type mappings to support Float4_e2m1fn_x2 for ROCm 7.0.

  • aten/src/ATen/cuda/tunable/GemmHipblaslt.h: Added a constexpr mapping for Float4_e2m1fn_x2 based on the ROCm version.

Enhancements to testing (@jagadish-amd):

  • test/test_matmul_cuda.py: Updated tests to include the new scalar types (Float4_e2m1fn_x2) and recipes (mxfp4). Added logic to handle the different scaling recipes and validate compatibility with ROCm and CUDA versions.

These changes improve compatibility with newer hardware and software versions, enhance functionality for matrix operations, and ensure robust testing for the added features.
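As a concrete illustration of the block-wise scaling path, here is a minimal sketch of an MX-FP8 matmul with per-1x32-block e8m0 scales. It assumes hardware and a toolchain with e8m0 block-scale support (gfx950 + ROCm 7.0 as validated here) and a PyTorch build that exposes torch.float8_e8m0fnu; the exact recipe, including the swizzled scale layout required on the CUDA path, follows test/test_matmul_cuda.py rather than this sketch.

```python
# Minimal sketch (not the PR diff): an MX-FP8 matmul with one e8m0 scale per
# 1x32 block along K. Requires e8m0 block-scale support (e.g. gfx950 + ROCm 7.0)
# and a PyTorch build exposing torch.float8_e8m0fnu. Dimensions must be
# multiples of 32.
import torch

M, K, N = 128, 128, 128  # all multiples of 32

A = torch.randn(M, K, device="cuda").to(torch.float8_e4m3fn)
B = torch.randn(N, K, device="cuda").to(torch.float8_e4m3fn)

# e8m0 stores only a biased exponent (bias 127); byte value 127 encodes 2**0 = 1.0.
A_scale = torch.full((M, K // 32), 127, dtype=torch.uint8, device="cuda").view(torch.float8_e8m0fnu)
B_scale = torch.full((N, K // 32), 127, dtype=torch.uint8, device="cuda").view(torch.float8_e8m0fnu)

# On ROCm the scales are consumed row-major as-is; on the CUDA path they must
# additionally be rearranged into the blocked ("swizzled") layout, as the tests do.
out = torch._scaled_mm(A, B.t(), scale_a=A_scale, scale_b=B_scale, out_dtype=torch.bfloat16)
```

The unit scales keep the example self-contained; real callers derive per-block scales from the data, as the tests do.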

cc @jeffdaily @sunway513 @jithunnair-amd @pruthvistony @ROCmSupport @dllehr-amd @jataylo @hongxiayang @naromero77amd @voznesenskym @penguinwu @EikanWang @jgong5 @Guobing-Chen @XiaobingSuper @zhuhaozhe @blzheng @wenzhe-nrv @jiayisunx @chenyang78 @kadeng @chauhang @amjames @Lucaskabela

@pytorch-bot

pytorch-bot bot commented Apr 15, 2025

🔗 Helpful Links

🧪 See artifacts and rendered test results at hud.pytorch.org/pr/151360

Note: Links to docs will display an error until the docs builds have been completed.

❌ 7 New Failures

As of commit 73022c3 with merge base d46768d:


This comment was automatically generated by Dr. CI and updates every 15 minutes.

@pytorch-bot pytorch-bot bot added the module: rocm AMD GPU support for Pytorch label Apr 15, 2025
@mikaylagawarecki mikaylagawarecki added the triaged This issue has been looked at by a team member, and triaged and prioritized into an appropriate module label Apr 16, 2025
@pruthvistony
Collaborator

@petrex ,
Can we enable the UTs which test FP4 for _scaled_gemm()

@pruthvistony pruthvistony marked this pull request as draft May 13, 2025 21:27
@petrex
Contributor Author

petrex commented May 13, 2025

@petrex , Can we enable the UTs which test FP4 for _scaled_gemm()

fp4 support requires gfx950. any plan to deploy that in CI?

pruthvistony pushed a commit to ROCm/pytorch that referenced this pull request May 29, 2025
This PR enables mx data type support on ROCm.

Current test stats (counting only the block-wise scale tests):
PYTORCH_TEST_WITH_ROCM=1 python test/test_matmul_cuda.py -k test_blockwise -v
Ran 452 tests in 17.470s
FAILED (failures=2, errors=2, skipped=337)
111 tests pass

**fp8 mx data type sample test case.**
PYTORCH_TEST_WITH_ROCM=1 python test/test_matmul_cuda.py TestFP8MatmulCudaCUDA.test_blockwise_mxfp8_nvfp4_numerics_test_case_name_a_eye_b_eye_fast_accum_False_128_128_128_recipe_mxfp8_cuda -v

**HipblasLT log:**
hipblaslt-bench --api_method c -m 128 -n 128 -k 128 --lda 128 --ldb 128 --ldc 128 --ldd 128 --stride_a 0 --stride_b 0 --stride_c 0 --stride_d 0 --alpha 1 --beta 0 --transA T --transB N --batch_count 1 --scaleA 3 --scaleB 3 --a_type f8_r --b_type f8_r --c_type bf16_r --d_type bf16_r --compute_type f32_r --algo_method index --solution_index -2147220478 --rotating 0 --cold_iters 0 --iters 0

**fp4 mx data type sample test case.**
PYTORCH_TEST_WITH_ROCM=1 python test/test_matmul_cuda.py TestFP8MatmulCudaCUDA.test_blockwise_mxfp8_nvfp4_numerics_test_case_name_a_eye_b_eye_fast_accum_False_128_128_128_recipe_nvfp4_cuda -v

**HipblasLT log:**
hipblaslt-bench --api_method c -m 128 -n 128 -k 128 --lda 128 --ldb 128 --ldc 128 --ldd 128 --stride_a 0 --stride_b 0 --stride_c 0 --stride_d 0 --alpha 1 --beta 0 --transA T --transB N --batch_count 1 --scaleA 3 --scaleB 3 --a_type f4_r --b_type f4_r --c_type bf16_r --d_type bf16_r --compute_type f32_r --algo_method index --solution_index -2147220478 --rotating 0 --cold_iters 0 --iters 0


Commits:

1. **ROCm MX-FP8 Gemm** (PR from @petrex)
Ported the patch from pytorch#147553.
Commented out a few lines to avoid a compilation error (see the TODO comments).

2. **Refine _platform_supports_mx_gemm check**

3. **For mx fp8, A and B need not be kFloat8_e8m0fnu type**

4.  **Add fp4 support** (PR from @petrex )
Ported the patch from pytorch#151360
Added fp4 type in aten/src/ATen/cuda/CUDADataType.h
Added more mappings in aten/src/ATen/cuda/CUDADataType.h
Use e8m0 scaling dtype for fp4 test case for ROCm in
test/test_matmul_cuda.py

5.  **test_matmul: change code to correctly skip**
6. **test_matmul: skip if nv format**
skip tests whose matrix dimensions are not multiples of 32
skip the conversion to swizzled format
7.  **add fp4 support for data_to_mx_scale**
8. **test_matmul: Add mxfp4 test case**

---------

Signed-off-by: Jagadish Krishnamoorthy <jagadish.krishnamoorthy@amd.com>
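For readers following the commit list, here is a rough sketch of what a data_to_mx_scale-style helper computes. It is an illustration under the assumption of one power-of-two e8m0 scale per 1x32 block, not the actual helper in test/test_matmul_cuda.py, which may differ in rounding and clamping details.

```python
# Rough sketch of a data_to_mx_scale-style helper (illustrative, not the actual
# test code): derive one power-of-two e8m0 scale per 1x32 block along K.
import torch

def data_to_mx_scale_sketch(x: torch.Tensor, block_size: int = 32,
                            target_max_pow2: int = 8) -> torch.Tensor:
    """x: (M, K) float tensor, K a multiple of block_size.
    target_max_pow2: 8 for fp8 e4m3 (largest power of two 256), 2 for fp4 e2m1 (largest 4).
    Returns an (M, K // block_size) tensor viewed as torch.float8_e8m0fnu."""
    blocks = x.reshape(x.shape[0], -1, block_size)
    amax = blocks.abs().amax(dim=-1).clamp(min=2.0 ** -127)
    # Choose the exponent that maps each block's max near the top of the target range.
    exp = torch.floor(torch.log2(amax)) - target_max_pow2
    # e8m0 stores only a biased exponent (bias 127), one byte per scale.
    return (exp + 127).clamp(0, 254).to(torch.uint8).view(torch.float8_e8m0fnu)
```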
@petrex petrex changed the title ROCm mx-fp4 Support ROCm mx-fp4/mx-fp8 Support Jun 3, 2025
@petrex petrex changed the title ROCm mx-fp4/mx-fp8 Support ROCm OCP Micro-scaling Format (mx-fp8/mx-fp4) Support Jun 3, 2025
@petrex petrex marked this pull request as ready for review June 3, 2025 19:50
This was referenced Jun 3, 2025
@drisspg drisspg self-requested a review June 20, 2025 19:42
A_scale = to_blocked(A_scale)
B_scale = to_blocked(B_scale)
if not torch.version.hip:
    A_scale = to_blocked(A_scale)
Contributor

The row-major layout of the scales is exactly what's expected for ROCm? Out of curiosity, is this maximal perf?


Yes, this is what is expected from hipBLASLt within ROCm.

Contributor

Also curious: are there alignment constraints for the scales? Judging by the fact that all the parametrized tests are passing, it seems there are not?

Contributor Author

As far as I recall, all applicable tests on ROCm 7.0 are passing. There may be additional tests that aren't currently supported in our setup (gfx950 + ROCm 7.0).
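To summarize the layout question above in code, a small sketch of the gating being discussed (a paraphrase of the test-side logic, with to_blocked standing in for the helper used in test/test_matmul_cuda.py, not the exact PR diff):

```python
# Sketch only, not the exact PR diff: gate the scale-layout conversion on HIP.
import torch

def prepare_mx_scales(A_scale, B_scale, to_blocked):
    if not torch.version.hip:
        # NVIDIA path: rearrange scales into the blocked ("swizzled") layout.
        A_scale = to_blocked(A_scale)
        B_scale = to_blocked(B_scale)
    # ROCm path: hipBLASLt expects plain row-major scales, so pass them through.
    return A_scale, B_scale
```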

@drisspg
Contributor

drisspg commented Jun 20, 2025

Is this code exercised in CI?

Update:
just saw

fp4 support requires gfx950. any plan to deploy that in CI?

@drisspg drisspg added the release notes: rocm mandatorylabel label Jun 20, 2025
@jagadish-amd
Contributor

Is this code exercised in CI?

Update: just saw

fp4 support requires gfx950. any plan to deploy that in CI?

right, not yet on CI (requires gfx950)

@petrex
Contributor Author

petrex commented Jul 14, 2025

refer to #158221

@pytorch-bot pytorch-bot bot added the ciflow/rocm Trigger "default" config CI on ROCm label Jul 17, 2025
@pytorch-bot pytorch-bot bot removed the ciflow/rocm Trigger "default" config CI on ROCm label Jul 30, 2025
…elated functions. Update error messages to reflect ROCm 6.5 compatibility. Add HIP data type mapping for Float4_e2m1fn_x2. Ensure proper version checks for ROCm in CUDA operations.
@pytorch-bot

pytorch-bot bot commented Aug 15, 2025

You are not authorized to force merges to this repository. Please use the regular @pytorchmergebot merge command instead

@jeffdaily
Collaborator

@pytorchbot merge -f 'tests failures irrelevant(CPU based)'

@pytorchmergebot
Collaborator

Merge started

Your change will be merged immediately since you used the force (-f) flag, bypassing any CI checks (ETA: 1-5 minutes). Please use -f as last resort and instead consider -i/--ignore-current to continue the merge ignoring current failures. This will allow currently pending tests to finish and report signal before the merge.

Learn more about merging in the wiki.

Questions? Feedback? Please reach out to the PyTorch DevX Team

Advanced Debugging
Check the merge workflow status here

@weifengpy
Contributor

this seems to be causing torchtitan/torchao failures for H100, see pytorch/ao#2843

@drisspg
Contributor

drisspg commented Aug 22, 2025

@pytorchbot revert

@pytorch-bot

pytorch-bot bot commented Aug 22, 2025

❌ 🤖 pytorchbot command failed:

@pytorchbot revert: error: the following arguments are required: -m/--message, -c/--classification

usage: @pytorchbot revert -m MESSAGE -c
                          {nosignal,ignoredsignal,landrace,weird,ghfirst}

Try @pytorchbot --help for more info.

@drisspg
Contributor

drisspg commented Aug 22, 2025

@pytorchbot revert "Broke fp8 rowwise on cuda 12.9 + pytorch/ao#2843"

@pytorch-bot

pytorch-bot bot commented Aug 22, 2025

❌ 🤖 pytorchbot command failed:

@pytorchbot revert: error: the following arguments are required: -m/--message, -c/--classification

usage: @pytorchbot revert -m MESSAGE -c
                          {nosignal,ignoredsignal,landrace,weird,ghfirst}

Try @pytorchbot --help for more info.

@drisspg
Contributor

drisspg commented Aug 22, 2025

@pytorchbot revert -m "Broke fp8 rowwise on cuda 12.9 + pytorch/ao#2843" -c landrace

@pytorchmergebot
Collaborator

@pytorchbot successfully started a revert job. Check the current status here.
Questions? Feedback? Please reach out to the PyTorch DevX Team

@pytorchmergebot
Collaborator

Reverting PR 151360 failed

Reason: Command git -C /home/runner/work/pytorch/pytorch revert --no-edit e389a08dcd4f703a113edd3b252fe25572a8cea5 returned non-zero exit code 1

Auto-merging aten/src/ATen/native/cuda/Blas.cpp
Auto-merging test/test_matmul_cuda.py
CONFLICT (content): Merge conflict in test/test_matmul_cuda.py
Auto-merging torch/testing/_internal/common_cuda.py
error: could not revert e389a08dcd4... AMD/ROCm OCP Micro-scaling Format (mx-fp8/mx-fp4) Support (#151360)
hint: After resolving the conflicts, mark them with
hint: "git add/rm <pathspec>", then run
hint: "git revert --continue".
hint: You can instead skip this commit with "git revert --skip".
hint: To abort and get back to the state before "git revert",
hint: run "git revert --abort".
hint: Disable this message with "git config set advice.mergeConflict false"
Details for Dev Infra team (raised by workflow job)

@drisspg
Contributor

drisspg commented Aug 22, 2025

hmm no clean revert, can you forward fix this
@petrex
https://github.com/petrex/pytorch/blob/73022c3d3262c9adb351a4a8c739b98ad3c5ff94/aten/src/ATen/cuda/CUDABlas.cpp#L1868-L1875

it should be good on 12.9 +

can-gaa-hou pushed a commit to can-gaa-hou/pytorch that referenced this pull request Aug 22, 2025
matmulDescB = HIPBLASLT_MATMUL_DESC_B_SCALE_POINTER_VEC_EXT;
}
#endif // if defined(USE_ROCM) && !defined(HIPBLASLT_OUTER_VEC) && defined(HIPBLASLT_VEC_EXT)
else if (mat1_scale_dtype == kFloat8_e8m0fnu && mat2_scale_dtype == kFloat8_e8m0fnu) {
Contributor

This check need not be inside #if defined(USE_ROCM) && !defined(HIPBLASLT_OUTER_VEC) && defined(HIPBLASLT_VEC_EXT)

jagadish-amd added a commit to jagadish-amd/pytorch that referenced this pull request Sep 16, 2025
PR pytorch#151360 added mx fp8 and fp4 support on ROCm.
However, on recent upstream, the scaling function in Blas.cpp along with the test_matmul_cuda changes triggered failures.
This patch corrects the is_blockwise_1x32_scaling function and fixes a minor bug in test_matmul_cuda.

Testing result on gfx950 with ROCm 7.0:

PYTORCH_TEST_WITH_ROCM=1 python test/test_matmul_cuda.py -k test_blockwise -v
Ran 452 tests in 22.698s
OK, 111 passed
This is the same as before (when PR 151360 was merged).

Signed-off-by: Jagadish Krishnamoorthy <jagadish.krishnamoorthy@amd.com>
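For context on what the corrected check verifies, a hedged Python paraphrase of a 1x32 block-wise scale-shape test follows; the real is_blockwise_1x32_scaling lives in C++ in Blas.cpp and may also check strides and padding.

```python
# Illustrative paraphrase, not the C++ implementation: a 1x32 block-wise scale
# for an (M, K) operand is an e8m0 tensor with one value per 32-wide block of K.
import torch

def looks_like_blockwise_1x32_scale(mat: torch.Tensor, scale: torch.Tensor) -> bool:
    m, k = mat.shape
    return (scale.dtype == torch.float8_e8m0fnu
            and k % 32 == 0
            and tuple(scale.shape) == (m, k // 32))
```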
markc-614 pushed a commit to markc-614/pytorch that referenced this pull request Sep 17, 2025
pytorchmergebot pushed a commit that referenced this pull request Sep 19, 2025
PR #151360 added mx fp8 and fp4 support on ROCm.
1. However, on recent upstream, the scaling function in Blas.cpp along with the test_matmul_cuda changes triggered failures.
This patch corrects the is_blockwise_1x32_scaling function.

2. Fixes the m, n, k dimensions for the ROCm mx case.

3. Modifies FP4E2M1FN_LARGEST_POW2 (the largest power of 2 representable in `torch.float4_e2m1fn_x2`) to 2.
This resulted in a higher SQNR value for the mx fp4 test.

Testing result on gfx950 with ROCm 7.0:

PYTORCH_TEST_WITH_ROCM=1 python test/test_matmul_cuda.py -k test_blockwise -v
Ran 452 tests in 22.698s
OK, 111 passed
This is the same as before (when PR 151360 was merged).

Pull Request resolved: #163127
Approved by: https://github.com/jeffdaily

Co-authored-by: Jeff Daily <jeff.daily@amd.com>
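A quick sanity check of item 3: in the e2m1 (fp4) value set the largest finite value is 6, but the largest representable power of two is 4, so the constant becomes 2. The arithmetic below follows from the format definition and is not code from the patch.

```python
# e2m1 (fp4) finite values per the OCP MX spec; the largest power of two is 4,
# so FP4E2M1FN_LARGEST_POW2 = log2(4) = 2.
import math

E2M1_FINITE_VALUES = [0.0, 0.5, 1.0, 1.5, 2.0, 3.0, 4.0, 6.0]
largest_pow2 = max(v for v in E2M1_FINITE_VALUES if v > 0 and math.log2(v).is_integer())
assert largest_pow2 == 4.0 and int(math.log2(largest_pow2)) == 2
```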
mansiag05 pushed a commit to mansiag05/pytorch that referenced this pull request Sep 22, 2025
cleonard530 pushed a commit to cleonard530/pytorch that referenced this pull request Sep 22, 2025
dsashidh pushed a commit to dsashidh/pytorch that referenced this pull request Sep 26, 2025
pytorchmergebot pushed a commit that referenced this pull request Oct 9, 2025
Useful to have PR testing for PRs such as #151360

Pull Request resolved: #160215
Approved by: https://github.com/malfet, https://github.com/atalman

Co-authored-by: Jeff Daily <jeff.daily@amd.com>
Chao1Han pushed a commit to Chao1Han/pytorch that referenced this pull request Oct 21, 2025

Labels

  • ciflow/inductor-rocm (Trigger "inductor" config CI on ROCm)
  • ciflow/periodic-rocm-mi300 (Trigger "distributed" config CI on ROCm MI300)
  • ciflow/rocm (Trigger "default" config CI on ROCm)
  • ciflow/rocm-mi300 (Trigger "default" config CI on ROCm MI300)
  • ciflow/trunk (Trigger trunk jobs on your pull request)
  • Merged
  • module: dynamo
  • module: rocm (AMD GPU support for Pytorch)
  • open source
  • release notes: rocm (mandatorylabel)
  • triaged (This issue has been looked at by a team member, and triaged and prioritized into an appropriate module)
