[CPU][Inductor] Improve performance of A16W8 GEMM template by Xia-Weiwen · Pull Request #161148 · pytorch/pytorch · GitHub

Conversation

@Xia-Weiwen
Collaborator

@Xia-Weiwen Xia-Weiwen commented Aug 21, 2025

Summary
This PR improves the performance of the A16W8 GEMM template by:

  • Removing the config with block_n=48 & block_m=16, as it is not very efficient.
  • Using the AMX microkernel when M >= 5, so that AMX is used instead of AVX512 for M = 5~31.
  • Converting int8 values to bf16 with intrinsics instead of `at::vec::convert`, as the latter does not have an optimized implementation for this case (see the illustrative sketch below).

We saw performance gains of over 10% in various cases when running Llama-3.1-8b-instruct.
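As an illustration of the third point, here is a minimal, self-contained sketch of the int8 -> bf16 widening path using AVX512 intrinsics. It is not the exact kernel code in the template; it assumes a compiler with AVX512F and AVX512-BF16 support (e.g. built with `-mavx512f -mavx512bf16`), and the helper name is made up for this example.

```cpp
#include <immintrin.h>
#include <cstdint>

// Illustrative helper (not the template's actual code): widen 16 int8 values
// to bf16 via int8 -> int32 (sign extend) -> f32 -> bf16 (round-to-nearest-even).
// Requires AVX512F and AVX512-BF16 (-mavx512f -mavx512bf16).
static inline void int8_to_bf16_x16(const int8_t* src, uint16_t* dst) {
  // 1) Load 16 int8 values.
  __m128i v8 = _mm_loadu_si128(reinterpret_cast<const __m128i*>(src));
  // 2) Sign-extend to 16 x int32.
  __m512i v32 = _mm512_cvtepi8_epi32(v8);
  // 3) Convert to 16 x f32.
  __m512 f32 = _mm512_cvtepi32_ps(v32);
  // 4) Convert f32 -> bf16 with hardware round-to-nearest-even.
  __m256i bf16 = (__m256i)_mm512_cvtneps_pbh(f32);
  // 5) Store the 16 bf16 values as raw 16-bit words.
  _mm256_storeu_si256(reinterpret_cast<__m256i*>(dst), bf16);
}
```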

Test plan
Already covered by UT.

cc @jgong5 @mingfeima @XiaobingSuper @sanchitintel @ashokei @jingxu10 @voznesenskym @penguinwu @EikanWang @Guobing-Chen @zhuhaozhe @blzheng @wenzhe-nrv @jiayisunx @ipiszy @chenyang78 @kadeng @muchulee8 @amjames @chauhang @aakhundov @coconutruben

@pytorch-bot

pytorch-bot bot commented Aug 21, 2025

🔗 Helpful Links

🧪 See artifacts and rendered test results at hud.pytorch.org/pr/161148

Note: Links to docs will display an error until the docs builds have been completed.

❗ 1 Active SEV

There is 1 currently active SEV. If your PR is affected, please review it.

✅ No Failures

As of commit 746766f with merge base fa76256:
💚 Looks good so far! There are no failures yet. 💚

This comment was automatically generated by Dr. CI and updates every 15 minutes.

@sanchitintel
Collaborator

sanchitintel commented Aug 26, 2025

Hi,

> Converting int8 values to bf16 with intrinsics instead of at::vec::convert as the latter does not have optimized implementation for this case.

Please advise if it's somehow possible to optimize at::vec::convert instead.

Thank you!

@Xia-Weiwen
Collaborator Author

> Hi,
>
> Converting int8 values to bf16 with intrinsics instead of at::vec::convert as the latter does not have optimized implementation for this case.
>
> Please advise if it's possible to somehow optimize at::vec::convert instead.
>
> Thank you!

There is no int8->bf16 specialization of `at::vec::convert` right now. We just need to add one.

@Xia-Weiwen
Collaborator Author

Hi @CaoE @mingfeima Could you please review? Thanks.

// 4) Convert to f32
__m512 f32 = _mm512_cvtepi32_ps(v32);
// 5) Convert f32 -> bf16 (round-to-nearest-even)
__m256i bf16 = (__m256i)_mm512_cvtneps_pbh(f32);
Collaborator

@CaoE CaoE Aug 28, 2025


Since intrinsics are used, it is better to check whether the compiler supports them, e.g., _mm512_cvtneps_pbh.
If the compiler does not support them, it will choose the ATen linear and lose the opportunity of using the AMX microgemm.
Maybe we can do something like #147368.
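For illustration, one possible shape of such a guard is a preprocessor check on the compiler's feature macros, with a fallback when AVX512-BF16 intrinsics are unavailable. This is only a sketch of the general idea; the actual fix may instead follow the compile-test approach referenced above (#147368).

```cpp
#include <immintrin.h>

// Sketch of a feature guard (illustrative, not the PR's actual code):
// only emit the AVX512-BF16 intrinsic when the compiler advertises support.
#if defined(__AVX512F__) && defined(__AVX512BF16__)
static inline __m256i cvt_fp32_to_bf16(__m512 f32) {
  return (__m256i)_mm512_cvtneps_pbh(f32);  // hardware round-to-nearest-even
}
#else
// Without AVX512-BF16 support, a scalar/emulated conversion would go here,
// or this codegen path would be skipped so the ATen linear fallback is used.
#endif
```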

Collaborator

BTW, which versions of the compiler support these instructions?

Collaborator Author

Updated. Thanks.

@Xia-Weiwen Xia-Weiwen requested a review from CaoE August 28, 2025 07:46
@dataclasses.dataclass
class VecAMX(VecAVX512):
_arch_flags = VecAVX512._arch_flags + " -mamx-tile -mamx-bf16 -mamx-int8"
_arch_flags = VecAVX512().build_arch_flags() + " -mamx-tile -mamx-bf16 -mamx-int8"
Collaborator

Could you please double check whether VecAVX512().build_arch_flags() will do self.check_build(VecAMX._avx512_bf16_code) ?

Collaborator Author

Unfortunately no. I have updated this part to ensure we get the correct flags. Thanks.
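For context, a check such as check_build(VecAMX._avx512_bf16_code) is expected to compile a small probe program with the candidate arch flags and treat a successful build as evidence of support. A hypothetical probe of that kind might look like the following (illustrative only; the real probe code may differ):

```cpp
// Hypothetical probe translation unit: if this builds with the chosen arch
// flags (e.g. AVX512-BF16 plus the AMX flags), the intrinsics needed by the
// A16W8 template can be assumed to be available.
#include <immintrin.h>

int main() {
  __m512 x = _mm512_set1_ps(1.0f);
  __m256bh y = _mm512_cvtneps_pbh(x);  // requires AVX512-BF16
  (void)y;
  return 0;
}
```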

@Xia-Weiwen Xia-Weiwen requested a review from CaoE August 28, 2025 08:52
@Xia-Weiwen Xia-Weiwen marked this pull request as ready for review August 29, 2025 01:32
@CaoE CaoE added the ciflow/trunk label Aug 29, 2025
@Xia-Weiwen Xia-Weiwen requested a review from jansel August 29, 2025 05:24
@Xia-Weiwen
Collaborator Author

@pytorchbot merge

@pytorchmergebot
Collaborator

Merge started

Your change will be merged once all checks pass (ETA 0-4 Hours).

Learn more about merging in the wiki.

Questions? Feedback? Please reach out to the PyTorch DevX Team

Advanced Debugging: check the merge workflow status.

markc-614 pushed a commit to markc-614/pytorch that referenced this pull request Sep 17, 2025 (Pull Request resolved: pytorch#161148; approved by https://github.com/CaoE and https://github.com/jansel)

mansiag05 pushed a commit to mansiag05/pytorch that referenced this pull request Sep 22, 2025
@zou3519
Contributor

zou3519 commented Oct 8, 2025

@Xia-Weiwen @CaoE @jansel Does this PR improve CPU-only llama3 performance, or does it also affect llama3 running on CUDA? We're seeing something weird where this PR appears to affect llama4 performance on CUDA (maybe there are some CPU pieces in there, I'm not sure).

atalman added a commit to atalman/pytorch that referenced this pull request Oct 8, 2025
atalman added a commit to atalman/pytorch that referenced this pull request Oct 8, 2025
@Xia-Weiwen
Collaborator Author

@zou3519 It should not affect CUDA. It's for CPU only. It has no effect unless you run A16W8 (bf16-int8) GEMMs on CPU with AMX.


Labels

ciflow/inductor, ciflow/trunk, intel, Merged, module: inductor, open source, topic: not user facing
