[Inductor][CPP] Optimize WOQ INT8 wgt dequant in AMX GEMM template by leslie-fang-intel · Pull Request #136630 · pytorch/pytorch · GitHub

Conversation

@leslie-fang-intel
Collaborator

@leslie-fang-intel leslie-fang-intel commented Sep 25, 2024

Stack from ghstack (oldest at bottom):

Summary
Optimize WOQ INT8 AMX GEMM performance by changing the int8 -> bf16 weight conversion.
Previously, 16 int8 elements were loaded at a time and converted to 16 BF16 elements.
With this change, 32 int8 elements are loaded at a time and converted into a full cache line of 32 BF16 elements, which is more efficient.
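
A rough sketch of the two patterns with at::vec (illustrative only: the function names and the at::BFloat16 destination type are assumptions, and the "old" pattern is a reconstruction; the actual template hunk, where {{input_t}} is the input element type, is quoted in the review thread below):

    #include <ATen/cpu/vec/vec.h>

    // Old pattern (roughly): widen 16 int8 weights per call into 16 BF16 lanes.
    void dequant_b_16(const int8_t* src, at::BFloat16* dst) {
      auto b_int8 = at::vec::Vectorized<int8_t>::loadu(src, static_cast<int64_t>(16));
      auto b_bf16 = at::vec::convert<at::BFloat16>(b_int8);
      b_bf16.store(dst, 16);  // only half a 64-byte cache line per step
    }

    // New pattern: widen 32 int8 weights at once into a full cache line of BF16.
    void dequant_b_32(const int8_t* src, at::BFloat16* dst) {
      auto b_int8 = at::vec::Vectorized<int8_t>::loadu(src, static_cast<int64_t>(32));
      auto b_bf16 = at::vec::convert<at::BFloat16>(b_int8);
      b_bf16.store(dst);      // 32 BF16 elements = one 64-byte cache line
    }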

Performance before

AUTOTUNE _weight_int8pack_mm(4096x4096, 4096x4096, 4096)
  cpp_packed_gemm_0 38.0439 ms 100.0%
  _weight_int8pack_mm 50.2524 ms 75.7%
SingleProcess AUTOTUNE benchmarking takes 1.1087 seconds and 1.9791 seconds precompiling
AUTOTUNE _weight_int8pack_mm(4096x4096, 11008x4096, 11008)
  cpp_packed_gemm_4 78.2038 ms 100.0%
  _weight_int8pack_mm 119.1962 ms 65.6%
SingleProcess AUTOTUNE benchmarking takes 1.9274 seconds and 1.9949 seconds precompiling
AUTOTUNE _weight_int8pack_mm(4096x11008, 4096x11008, 4096)
  cpp_packed_gemm_6 79.2368 ms 100.0%
  _weight_int8pack_mm 118.3212 ms 67.0%
SingleProcess AUTOTUNE benchmarking takes 1.9200 seconds and 2.0015 seconds precompiling
AUTOTUNE _weight_int8pack_mm(4096x4096, 32000x4096, 32000)
  cpp_packed_gemm_224 225.7201 ms 100.0%
  _weight_int8pack_mm 388.5588 ms 58.1%

Performance after this PR

AUTOTUNE _weight_int8pack_mm(4096x4096, 4096x4096, 4096)
  cpp_packed_gemm_0 11.0086 ms 100.0%
  _weight_int8pack_mm 50.2918 ms 21.9%
SingleProcess AUTOTUNE benchmarking takes 1.0837 seconds and 2.0301 seconds precompiling
AUTOTUNE _weight_int8pack_mm(4096x4096, 11008x4096, 11008)
  cpp_packed_gemm_4 24.3528 ms 100.0%
  _weight_int8pack_mm 119.8492 ms 20.3%
SingleProcess AUTOTUNE benchmarking takes 1.8303 seconds and 1.8195 seconds precompiling
AUTOTUNE _weight_int8pack_mm(4096x11008, 4096x11008, 4096)
  cpp_packed_gemm_6 24.6148 ms 100.0%
  _weight_int8pack_mm 119.1908 ms 20.7%
SingleProcess AUTOTUNE benchmarking takes 1.8315 seconds and 1.8352 seconds precompiling
AUTOTUNE _weight_int8pack_mm(4096x4096, 32000x4096, 32000)
  cpp_packed_gemm_224 78.1369 ms 100.0%
  _weight_int8pack_mm 387.6289 ms 20.2%
SingleProcess AUTOTUNE benchmarking takes 4.5059 seconds and 1.8010 seconds precompiling

cc @voznesenskym @penguinwu @EikanWang @jgong5 @Guobing-Chen @XiaobingSuper @zhuhaozhe @blzheng @wenzhe-nrv @jiayisunx @ipiszy @yf225 @chenyang78 @kadeng @muchulee8 @ColinPeppler @amjames @desertfire @chauhang

[ghstack-poisoned]
leslie-fang-intel added a commit that referenced this pull request Sep 25, 2024
ghstack-source-id: 6830ea4
Pull Request resolved: #136630
@pytorch-bot

pytorch-bot bot commented Sep 25, 2024

🔗 Helpful Links

🧪 See artifacts and rendered test results at hud.pytorch.org/pr/136630

Note: Links to docs will display an error until the docs builds have been completed.

✅ No Failures

As of commit b8b7923 (merge base could not be retrieved; please contact dev infra):
💚 Looks good so far! There are no failures yet. 💚

This comment was automatically generated by Dr. CI and updates every 15 minutes.

Comment on lines +612 to +614
auto b_int8 = at::vec::Vectorized<int8_t>::loadu(src, static_cast<int64_t>(32));  // load 32 int8 weight elements
auto b_bf16 = at::vec::convert<{{input_t}}>(b_int8);  // widen to 32 {{input_t}} (BF16 here) elements
b_bf16.store(dst);  // store one 64-byte cache line
Collaborator

@sanchitintel sanchitintel Sep 25, 2024


Thank you! I didn't know this was possible with at::vec, since the conversion happens from int8 -> int32 -> fp32 -> bf16. Looks like at::vec::convert can use multiple intermediate vector registers (two in this case for holding int32 & fp32 values).
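
For intuition, a rough sketch of what that int8 -> int32 -> fp32 -> bf16 chain could look like if written directly with intrinsics (illustrative only, assuming AVX512F + AVX512-BF16; this is not necessarily the exact sequence at::vec::convert emits, the helper name is made up, and bf16 values are stored here as raw 16-bit words):

    #include <immintrin.h>
    #include <cstdint>
    #include <cstring>

    // Widen 32 int8 weights into 32 bf16 values, going through int32 and fp32
    // intermediates held in separate vector registers.
    inline void int8_to_bf16_x32(const int8_t* src, uint16_t* dst) {
      __m256i i8 = _mm256_loadu_si256(reinterpret_cast<const __m256i*>(src));  // 32 x int8
      __m512i i32_lo = _mm512_cvtepi8_epi32(_mm256_extracti128_si256(i8, 0));  // low 16 -> int32
      __m512i i32_hi = _mm512_cvtepi8_epi32(_mm256_extracti128_si256(i8, 1));  // high 16 -> int32
      __m512 f32_lo = _mm512_cvtepi32_ps(i32_lo);                              // 16 x fp32
      __m512 f32_hi = _mm512_cvtepi32_ps(i32_hi);                              // 16 x fp32
      // cvtne2 packs its second operand into the low half of the result.
      __m512bh bf16 = _mm512_cvtne2ps_pbh(f32_hi, f32_lo);                     // 32 x bf16
      std::memcpy(dst, &bf16, sizeof(bf16));                                   // one 64-byte cache line
    }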

Collaborator

@sanchitintel sanchitintel Sep 25, 2024


@leslie-fang-intel, do you know if such a change is possible for the AVX512 micro-kernel as well, so that it could load multiple vector registers of B at a time? I mean, by using at::vec, and not intrinsics. Thanks!

Currently, it loads only one FP32 vector register of B at a time

(screenshot of the relevant AVX512 micro-kernel code omitted)

Collaborator Author


Yes, I guess so. Probably something like:

        auto b_int8 = at::vec::Vectorized<int8_t>::loadu(src, static_cast<int64_t>(16));  // load first 128 bits of 16 X int8
        auto b_fp32 = convert_int8_to_float<int8_t>(b_int8); // CVT to 16 X FP32

Could you give it a try to see whether it benefits the AVX micro-GEMM performance?

Collaborator


But this one is similar to the current implementation.

@sanchitintel sanchitintel changed the title [Inductor][CPP] Optimize WOQ INT8 wgt dequant [Inductor][CPP] Optimize WOQ INT8 wgt dequant in AMX GEMM template Sep 25, 2024
@sanchitintel sanchitintel requested a review from jgong5 September 25, 2024 14:02
@sanchitintel
Collaborator

@leslie-fang-intel, could there have been some copy-paste error pertaining to the perf data on M=4096, N=11008, K=4096 and M=4096, N=4096, K=11008?
I'm using an Intel Xeon SP Gen 4 (Platinum 8468H), and the _weight_int8pack_mm ATen kernel doesn't have similar performance for both configs with 48 physical cores (the difference is ~6%). At your end, though, both input shapes exhibited similar performance.

@leslie-fang-intel
Collaborator Author

leslie-fang-intel commented Sep 26, 2024

> @leslie-fang-intel, could there have been some copy-paste error pertaining to the perf data on M=4096, N=11008, K=4096 and M=4096, N=4096, K=11008? I'm using an Intel Xeon SP Gen 4 (Platinum 8468H), and the _weight_int8pack_mm ATen kernel doesn't have similar performance for both configs with 48 physical cores (the difference is ~6%). At your end, though, both input shapes exhibited similar performance.

I guess it may be due to differences between our systems. Re-running on my test system, I still see similar performance for both shapes.

@leslie-fang-intel
Collaborator Author

@pytorchbot merge

@pytorchmergebot
Collaborator

Merge started

Your change will be merged once all checks pass (ETA 0-4 Hours).

Learn more about merging in the wiki.

Questions? Feedback? Please reach out to the PyTorch DevX Team

Advanced Debugging
Check the merge workflow status here.

