[CPU][Inductor] Improve performance of A16W8 GEMM template #161148
Hi,
Please advise if it's somehow possible to optimize […]. Thank you!
There is not a specialization for int8->bf16 of `at::vec::convert`.
Hi @CaoE @mingfeima, could you please review? Thanks.
// 4) Convert to f32
__m512 f32 = _mm512_cvtepi32_ps(v32);
// 5) Convert f32 -> bf16 (round-to-nearest-even)
__m256i bf16 = (__m256i)_mm512_cvtneps_pbh(f32);
Since intrinsics are used, it is better to check whether the compiler supports them, e.g., `_mm512_cvtneps_pbh`. If the compiler does not support them, it should fall back to ATen linear and lose the opportunity to use the AMX microgemm. Maybe we can do something like #147368.
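For reference, this kind of check typically tries to compile a tiny probe translation unit with the extra flags and only takes the intrinsics path when compilation succeeds. A minimal sketch of such a probe (purely illustrative; the flags and the actual snippet used by #147368 / `cpu_vec_isa.py` may differ):

```cpp
// Probe: if this compiles with -mavx512f -mavx512bf16, the toolchain
// supports the AVX512-BF16 intrinsics used by the template.
#include <immintrin.h>

int main() {
    __m512 f32 = _mm512_set1_ps(1.0f);         // 16 packed fp32 lanes
    __m256bh bf16 = _mm512_cvtneps_pbh(f32);   // fp32 -> bf16 (round-to-nearest-even)
    // Store the result so the conversion is not optimized away.
    unsigned short out[16];
    _mm256_storeu_si256(reinterpret_cast<__m256i*>(out), (__m256i)bf16);
    return static_cast<int>(out[0]);
}
```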
BTW, which versions of the compiler support these instructions?
Updated. Thanks.
torch/_inductor/cpu_vec_isa.py (outdated diff)
  @dataclasses.dataclass
  class VecAMX(VecAVX512):
-     _arch_flags = VecAVX512._arch_flags + " -mamx-tile -mamx-bf16 -mamx-int8"
+     _arch_flags = VecAVX512().build_arch_flags() + " -mamx-tile -mamx-bf16 -mamx-int8"
Could you please double check whether `VecAVX512().build_arch_flags()` will do `self.check_build(VecAMX._avx512_bf16_code)`?
Unfortunately no. I have updated this part to ensure we get the correct flags. Thanks.
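For context, this kind of ISA check presumably compiles a small test snippet with the candidate flags. A rough sketch of a probe that would exercise the AMX flags appended above (illustrative only; this is not the actual `_avx512_bf16_code` snippet in `cpu_vec_isa.py`):

```cpp
// Compile-time probe: if this builds with -mamx-tile -mamx-bf16 -mamx-int8,
// the flags appended to _arch_flags are usable on this toolchain.
// (Not meant to be executed; running AMX code also needs tile configuration.)
#include <immintrin.h>

int main() {
    _tile_zero(0);            // requires -mamx-tile
    _tile_dpbf16ps(0, 1, 2);  // bf16 tile dot-product, requires -mamx-bf16
    _tile_dpbssd(0, 1, 2);    // int8 tile dot-product, requires -mamx-int8
    _tile_release();
    return 0;
}
```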
@pytorchbot merge
Merge started. Your change will be merged once all checks pass (ETA 0-4 hours). Learn more about merging in the wiki. Questions? Feedback? Please reach out to the PyTorch DevX Team.
Pull Request resolved: pytorch#161148
Approved by: https://github.com/CaoE, https://github.com/jansel
@Xia-Weiwen @CaoE @jansel does this PR improve CPU-only llama3 performance, or does it also affect llama3 running on CUDA? We're seeing something weird where this PR appears to affect llama4 performance on CUDA (maybe there are some CPU pieces in there, I'm not sure).
…ytorch#161148)" This reverts commit 75bc23c.
@zou3519 It should not affect CUDA. It's for CPU only. It has no effect unless you run A16W8 (bf16-int8) GEMMs on CPU with AMX.
Summary
This PR improves the performance of A16W8 GEMM template by
- Removing the config with block_n=48 & block_m=16 as it is not very efficient.
- Using the AMX microkernel when M >= 5 so that AMX is used instead of AVX512 for M=5~31.
- Converting int8 values to bf16 with intrinsics instead of `at::vec::convert`, as the latter does not have an optimized implementation for this case (see the sketch below).

We saw up to >10% performance gain in various cases of running Llama-3.1-8b-instruct.
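As a rough illustration of the third bullet, the conversion widens 16 int8 values to fp32 and then rounds to bf16 with AVX512-BF16. The helper name and exact step breakdown below are illustrative, not the template's generated code, which may interleave additional steps:

```cpp
#include <immintrin.h>
#include <cstdint>

// Convert 16 signed int8 values to 16 bf16 values (packed in a __m256i).
static inline __m256i int8_to_bf16(const int8_t* src) {
    // 1) Load 16 int8 values (128 bits).
    __m128i v8 = _mm_loadu_si128(reinterpret_cast<const __m128i*>(src));
    // 2) Sign-extend int8 -> int32 (16 x 32-bit lanes).
    __m512i v32 = _mm512_cvtepi8_epi32(v8);
    // 3) Convert int32 -> fp32.
    __m512 f32 = _mm512_cvtepi32_ps(v32);
    // 4) Convert fp32 -> bf16 (round-to-nearest-even).
    return (__m256i)_mm512_cvtneps_pbh(f32);
}
```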
Test plan
Already covered by UT.
cc @jgong5 @mingfeima @XiaobingSuper @sanchitintel @ashokei @jingxu10 @voznesenskym @penguinwu @EikanWang @Guobing-Chen @zhuhaozhe @blzheng @wenzhe-nrv @jiayisunx @ipiszy @chenyang78 @kadeng @muchulee8 @amjames @chauhang @aakhundov @coconutruben