[Inductor][CPP] Cache weight tiles in L1D for AMX int8 WoQ GEMM by sanchitintel · Pull Request #136688 · pytorch/pytorch

Conversation

@sanchitintel (Collaborator) commented Sep 25, 2024

Summary

The AMX ISA based GEMM micro-kernel template for int8 weight-only quantization (BF16 activation, int8 weights) should cache dequantized weights (int8 -> int32 -> fp32 -> bf16) so that they do not have to be dequantized again in subsequent inner-kernel calls that use the same weights.

This change leverages the fact that, even for the BF16 x BF16 GEMM template, cache blocking ensures that Nr * Kc weight elements are cached in the L1D cache (more info [here](https://static.sched.com/hosted_files/pytorch2024/59/TorchInductor%20CPU%20Backend%20Advancements%20-%20New%20Features%20and%20Performance%20Improvements_20240915.pdf)). Here, Nr is the register blocking size for the N dimension (at the granularity of the GEMM micro-kernel, it is currently also the cache blocking size for the N dimension, although that may change in the future), and Kc is the cache blocking size for the K dimension.

The figure below is from the document linked above -

[Figure (cache-blocking diagram from the linked slides): https://github.com/user-attachments/assets/e23e5476-d910-46d1-a9b3-cbf77de76d94]
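To make the caching idea concrete, here is a minimal, self-contained C++ sketch, not the actual Inductor-generated code: it assumes a plain row-major Kc x Nr weight block with per-output-channel scales (the real pre-packed layout is VNNI-interleaved for AMX), and names such as dequantize_B_block and float_to_bf16 are illustrative. The point is simply that the block is dequantized once into a small bf16 buffer that stays L1D-resident and is then reused across the inner-kernel calls over the M row-blocks.

```cpp
#include <cstdint>
#include <cstring>
#include <vector>

// bf16 held as raw bits; truncating float -> bf16 conversion (no rounding),
// which is enough for a sketch.
using bf16_t = uint16_t;

static inline bf16_t float_to_bf16(float f) {
  uint32_t bits;
  std::memcpy(&bits, &f, sizeof(bits));
  return static_cast<bf16_t>(bits >> 16);
}

// Dequantize one Kc x Nr weight block: int8 -> int32 -> fp32 (apply the
// per-output-channel scale) -> bf16, into a small contiguous buffer sized to
// stay resident in the L1D cache.
static void dequantize_B_block(const int8_t* B, const float* scales,
                               int64_t Kc, int64_t Nr, bf16_t* B_cache) {
  for (int64_t k = 0; k < Kc; ++k) {
    for (int64_t n = 0; n < Nr; ++n) {
      const int32_t q = static_cast<int32_t>(B[k * Nr + n]);
      B_cache[k * Nr + n] = float_to_bf16(static_cast<float>(q) * scales[n]);
    }
  }
}

// Outer structure: dequantize the block once, then every M row-block of the
// inner kernel reuses the cached bf16 weights instead of redoing the
// int8 -> bf16 conversion.
static void gemm_over_m_blocks(const int8_t* B, const float* scales,
                               int64_t Kc, int64_t Nr, int64_t Mc, int64_t Mr) {
  std::vector<bf16_t> B_cache(static_cast<size_t>(Kc * Nr));
  dequantize_B_block(B, scales, Kc, Nr, B_cache.data());
  for (int64_t m = 0; m < Mc; m += Mr) {
    // The AMX inner kernel (tile loads + tdpbf16ps) would consume A's rows
    // [m, m + Mr) together with B_cache here; omitted in this sketch.
  }
}
```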

Performance data

Collected on 48 physical cores of one socket of Intel Xeon Platinum 8468H (Xeon SP 4th gen). Intel OpenMP & tcmalloc were preloaded.

| M | N | K | Latency with ATen _weight_int8pack_mm | Latency with codegened templated GEMM (current main branch) | Latency with codegened templated GEMM (this PR) |
|---|---|---|---|---|---|
| 4096 | 4096 | 4096 | 45.844 ms | 9.322 ms | 5.2181 ms |
| 4096 | 11008 | 4096 | 127.618 ms | 24.6258 ms | 13.6046 ms |
| 4096 | 4096 | 11008 | 121.953 ms | 25.4692 ms | 10.2669 ms |
| 4096 | 32000 | 4096 | 478.450 ms | 75.3942 ms | 48.21 ms |

cc @jgong5 @mingfeima @XiaobingSuper @ashokei @jingxu10 @voznesenskym @penguinwu @EikanWang @Guobing-Chen @zhuhaozhe @blzheng @wenzhe-nrv @jiayisunx @ipiszy @yf225 @chenyang78 @kadeng @muchulee8 @ColinPeppler @amjames @desertfire @chauhang @aakhundov @XilunWu @H-Huang @awgu @kwen2501 @wanchaol @fegin @fduwjj @wz337 @wconstab @d4l3k @c-p-i-o @rec

@pytorch-bot bot commented Sep 25, 2024

🔗 Helpful Links

🧪 See artifacts and rendered test results at hud.pytorch.org/pr/136688

Note: Links to docs will display an error until the docs builds have been completed.

✅ No Failures

As of commit 3b3f56e with merge base e6e140c:
💚 Looks good so far! There are no failures yet. 💚

This comment was automatically generated by Dr. CI and updates every 15 minutes.

@sanchitintel sanchitintel force-pushed the cache_b_tiles_int8_woq_gemm branch from d9d6a90 to 1b3c163 on September 25, 2024 22:38
@sanchitintel sanchitintel added the topic: performance, intel, and release notes: intel labels on Sep 25, 2024
@sanchitintel sanchitintel changed the title from "Cache B tiles in L1D for AMX int8 WoQ micro-kernel" to "Cache weight tiles in L1D for int8 WoQ GEMM AMX micro-kernel" on Sep 25, 2024
@sanchitintel sanchitintel changed the title from "Cache weight tiles in L1D for int8 WoQ GEMM AMX micro-kernel" to "[Inductor][CPP] Cache weight tiles in L1D for AMX int8 WoQ GEMM" on Sep 25, 2024
@leslie-fang-intel (Collaborator) left a comment

nit: do we need to add an assert here

for (int64_t n = 0; n < N; n += {{block_n}}) {
that, for WoQ int8, N == block_n?
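For reference, a guard along the lines being suggested might look like the following; this is a hypothetical, standalone sketch (the function name and the explicit block_n parameter are illustrative, since the real kernel is generated from the Jinja template with {{block_n}} substituted in):

```cpp
#include <cassert>
#include <cstdint>

// Hypothetical guard: document/enforce the assumption that, for the int8 WoQ
// path, the N handed to the micro-kernel equals the register blocking size
// block_n, since the dequantized-weight buffer is sized and indexed for a
// single Kc x block_n tile.
static void micro_gemm_int8_woq(int64_t N, int64_t block_n) {
  assert(N == block_n && "int8 WoQ AMX micro-kernel assumes N == block_n");
  for (int64_t n = 0; n < N; n += block_n) {
    // ... dequantize the weight tile and run the AMX compute for this block ...
  }
}
```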

@sanchitintel (Collaborator, Author) commented Sep 26, 2024

nit: do we need to add an assert here

for (int64_t n = 0; n < N; n += {{block_n}}) {

that, for WoQ int8, N == block_n?

In the CPP template, this holds for all dtypes: when a micro-kernel is called, the cache blocking size for the N dimension equals the register blocking size for N. I added a comment in the code so that we can make the necessary changes (buffer allocation & index computation) if that ceases to be the case.

Thanks!

@sanchitintel (Collaborator, Author)

@pytorchbot rebase

@pytorchmergebot (Collaborator)

@pytorchbot started a rebase job onto refs/remotes/origin/viable/strict. Check the current status here

@pytorchmergebot (Collaborator)

Successfully rebased cache_b_tiles_int8_woq_gemm onto refs/remotes/origin/viable/strict, please pull locally before adding more changes (for example, via git checkout cache_b_tiles_int8_woq_gemm && git pull --rebase)

@pytorchmergebot pytorchmergebot force-pushed the cache_b_tiles_int8_woq_gemm branch from 83a7699 to d0e7815 on September 26, 2024 22:16
@pytorch-bot pytorch-bot bot added the ciflow/mps, module: cpu, module: dynamo, and oncall: distributed labels on Sep 26, 2024
@sanchitintel sanchitintel force-pushed the cache_b_tiles_int8_woq_gemm branch from d0e7815 to 6d4497d on September 26, 2024 22:44
@sanchitintel sanchitintel marked this pull request as ready for review on September 26, 2024 22:46
@sanchitintel sanchitintel removed the oncall: distributed and ciflow/mps labels on Sep 26, 2024
@pytorchmergebot (Collaborator)

Merge started

Your change will be merged once all checks pass (ETA 0-4 Hours).

Learn more about merging in the wiki.

Questions? Feedback? Please reach out to the PyTorch DevX Team

Advanced Debugging: check the merge workflow status here.

@pytorchmergebot (Collaborator)

Merge failed

Reason: 1 jobs have failed, first few of them are: inductor-periodic / cuda12.1-py3.10-gcc9-sm86-periodic-dynamo-benchmarks / test (aot_eager_torchbench, 2, 2, linux.g5.4xlarge.nvidia.gpu)

Details for Dev Infra team: raised by workflow job.

@sanchitintel (Collaborator, Author)

@pytorchbot merge

@pytorchmergebot (Collaborator)

Merge started

Your change will be merged once all checks pass (ETA 0-4 Hours).

Learn more about merging in the wiki.

Questions? Feedback? Please reach out to the PyTorch DevX Team

Advanced Debugging: check the merge workflow status here.

rahulsingh-intel pushed a commit to rahulsingh-intel/pytorch that referenced this pull request Nov 5, 2024

Pull Request resolved: pytorch#136688
Approved by: https://github.com/jgong5
pytorchmergebot pushed a commit that referenced this pull request Nov 7, 2024
…ntized (#139906)

@frost-intel discovered that some Inductor auto-tuning UTs for CPU are currently broken on machines supporting AMX ISA. That's because in #136688, I had reverted a change in the AMX GEMM micro-kernel that was introduced in #131887, but it looks like some other implementations introduced after the aforementioned change rely upon it, so it should not have been reverted.

Added a fix.

Ideally, a CI machine that supports AMX should cover these UTs (test/inductor/test_cpu_select_algorithm.py). We do have at least one CI machine that supports AMX.

Pull Request resolved: #139906
Approved by: https://github.com/leslie-fang-intel, https://github.com/jgong5
@leslie-fang-intel (Collaborator)

@pytorchbot revert -m "correctness issue in #140208"

@pytorch-bot bot commented Nov 9, 2024

❌ 🤖 pytorchbot command failed:

@pytorchbot revert: error: the following arguments are required: -c/--classification

usage: @pytorchbot revert -m MESSAGE -c
                          {nosignal,ignoredsignal,landrace,weird,ghfirst}

Try @pytorchbot --help for more info.

@leslie-fang-intel (Collaborator)

@pytorchbot revert -m "correctness issue in #140208" -c nosignal

@pytorchmergebot (Collaborator)

@pytorchbot successfully started a revert job. Check the current status here.
Questions? Feedback? Please reach out to the PyTorch DevX Team

@pytorchmergebot (Collaborator)

Reverting PR 136688 failed

Reason: Command git -C /home/runner/work/pytorch/pytorch revert --no-edit 3cbf0c0bbf7567816a3a2893d59b1d9bc936a90f returned non-zero exit code 1

Auto-merging test/inductor/test_cpu_select_algorithm.py
Auto-merging torch/_inductor/codegen/cpp_gemm_template.py
Auto-merging torch/_inductor/codegen/cpp_micro_gemm.py
CONFLICT (content): Merge conflict in torch/_inductor/codegen/cpp_micro_gemm.py
error: could not revert 3cbf0c0bbf7... [Inductor][CPP] Cache weight tiles in L1D for AMX int8 WoQ GEMM (#136688)
hint: After resolving the conflicts, mark them with
hint: "git add/rm <pathspec>", then run
hint: "git revert --continue".
hint: You can instead skip this commit with "git revert --skip".
hint: To abort and get back to the state before "git revert",
hint: run "git revert --abort".
hint: Disable this message with "git config advice.mergeConflict false"
Details for Dev Infra team: raised by workflow job.

@leslie-fang-intel (Collaborator) commented Nov 9, 2024

1. This PR seems to have caused a correctness issue, as reported in [Inductor][CPP] CPP GEMM Template WOQ int8 correctness failure #140208.
2. I would like to discuss the implementation in this PR a bit more. Considering here:
   load_dequantized_B(n);
   we are dealing with a matrix B of size Kc * Nr, which should be contiguous after packing, so why not just create a dequantized_B_buf with the same size/stride but a different data type? Then we can dequantize the elements from B into dequantized_B_buf with exactly the same index. I have written a POC for this idea: leslie-fang-intel@4c83e4e. It passes the correctness checks of the two failed test cases above and seems to simplify the implementation in this PR. A rough sketch of the idea follows this list.
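A rough, self-contained sketch of that suggestion (assumptions: the packed block is a plain row-major Kc x Nr layout with per-output-channel scales, bf16 is held as raw uint16_t bits, and names such as dequant_B_same_index are illustrative rather than taken from the POC commit):

```cpp
#include <cstdint>
#include <cstring>
#include <vector>

using bf16_t = uint16_t;  // bf16 stored as raw bits for the sketch

static inline bf16_t float_to_bf16(float f) {
  uint32_t bits;
  std::memcpy(&bits, &f, sizeof(bits));
  return static_cast<bf16_t>(bits >> 16);
}

// Because the packed Kc x Nr block is contiguous, dequantized_B_buf can mirror
// its size/stride exactly and be filled with the same flat index. The scale
// lookup assumes a row-major layout; the real pre-packed (VNNI) layout would
// need the matching column index instead of i % Nr.
static void dequant_B_same_index(const int8_t* B, const float* scales,
                                 int64_t Kc, int64_t Nr,
                                 std::vector<bf16_t>& dequantized_B_buf) {
  dequantized_B_buf.resize(static_cast<size_t>(Kc * Nr));
  for (int64_t i = 0; i < Kc * Nr; ++i) {
    dequantized_B_buf[static_cast<size_t>(i)] =
        float_to_bf16(static_cast<float>(B[i]) * scales[i % Nr]);
  }
}
```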

@jgong5 (Collaborator) commented Nov 11, 2024

Also, it seems we need more UTs to cover this feature...

@sanchitintel (Collaborator, Author) commented Nov 11, 2024

Also, it seems we need more UTs to cover this feature...

An existing UT caught it :(
We need CI coverage of AMX

@sanchitintel (Collaborator, Author) commented Nov 11, 2024

1. This PR seems to have caused a correctness issue, as reported in [Inductor][CPP] CPP GEMM Template WOQ int8 correctness failure #140208.

Thanks for the info! The size of the buffer should not have been multiplied by 2.

we are dealing with Matrix B of size `Kc * Nr` which should be contiguous after packing, why not we just create a `dequantized_B_buf` with same size/stride but different data type. Then we can dequant the elements from `B` into `dequantized_B_buf` with exactly same index. Written a POC here for this idea: [leslie-fang-intel@4c83e4e](https://github.com/leslie-fang-intel/pytorch/commit/4c83e4e75138e8fa6e0d58438f75b7718dc8a0cc). It can pass the correctness check of above 2 failed test case and seems simplifying the implementation in this PR.

Thanks! The patch assumes that all elements of a B tile are contiguous in memory, as opposed to being a tile like this -

[image: illustration of a B tile whose elements are not contiguous (strided) within the larger weight matrix]

Are elements of B tiles copied to some intermediate buffer, so that they can be accessed contiguously? Thanks!

@chunyuan-w explained that the weights are pre-packed so that each tile can be accessed contiguously. So I'm guessing that when @jgong5 adds the N != {{block_n}} implementation in the micro-kernel, two consecutive tiles would be used. I'll revise the patch accordingly. Thanks!
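To illustrate the layout point in the exchange above with a simplified model (the real packing also interleaves rows in VNNI order for AMX, which is omitted; the index helpers are hypothetical):

```cpp
#include <cstdint>

// Element (k, n) of the t-th Nr-wide tile of B in an unpacked, row-major
// K x N matrix: consecutive k values are N elements apart, so the tile's
// elements are not contiguous in memory.
static int64_t unpacked_index(int64_t k, int64_t n, int64_t t,
                              int64_t N, int64_t Nr) {
  return k * N + t * Nr + n;
}

// The same element after pre-packing the weights tile by tile: each Kc x Nr
// tile is laid out back to back, so a dequantization loop can walk it with a
// single flat index.
static int64_t packed_index(int64_t k, int64_t n, int64_t t,
                            int64_t Kc, int64_t Nr) {
  return t * (Kc * Nr) + k * Nr + n;
}
```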

pobin6 pushed a commit to pobin6/pytorch that referenced this pull request Dec 5, 2024

Pull Request resolved: pytorch#139906
Approved by: https://github.com/leslie-fang-intel, https://github.com/jgong5

Labels

ciflow/inductor, ciflow/trunk, intel, Merged, module: cpu, module: dynamo, module: inductor, open source, release notes: intel, topic: performance, triaged
