Simplify & rectify dequantized B buffer loading for AMX GEMM micro-kernel for WoQ int8 case by sanchitintel · Pull Request #140258 · pytorch/pytorch · GitHub

Conversation

@sanchitintel
Collaborator

@sanchitintel sanchitintel commented Nov 11, 2024

As suggested by @leslie-fang-intel in leslie-fang-intel@4c83e4e#diff-139642bd981df977f70f4c18c1c34bd1a85c1d6b9ffa06aaa98426ed83942a31R537: all elements of a B tile (a tile at the granularity of the micro-kernel, not an AMX tile) are contiguous because the B matrix is pre-packed, so the dequantized-buffer loading logic can be simplified. While the previous approach kept the elements to be loaded into a B AMX tile contiguous, the new approach incurs no performance penalty either: that data is already in L1D, so loading AMX tiles from non-contiguous dequantized B elements does not hurt performance.

Also rectified the size of the dequantized B buffer.

Fixes #140208.

A subsequent PR will factor out caching of dequantized int8 weights into a separate codegen function

cc @voznesenskym @penguinwu @EikanWang @jgong5 @Guobing-Chen @XiaobingSuper @zhuhaozhe @blzheng @wenzhe-nrv @jiayisunx @ipiszy @yf225 @chenyang78 @kadeng @muchulee8 @ColinPeppler @amjames @desertfire @chauhang @aakhundov

@pytorch-bot

pytorch-bot bot commented Nov 11, 2024

🔗 Helpful Links

🧪 See artifacts and rendered test results at hud.pytorch.org/pr/140258

Note: Links to docs will display an error until the docs builds have been completed.

❗ 1 Active SEV

There is 1 currently active SEV. If your PR is affected, please view it below:

✅ You can merge normally! (1 Unrelated Failure)

As of commit 709adfe with merge base fa63276:

BROKEN TRUNK - The following job failed but was also present on the merge base:

👉 Rebase onto the `viable/strict` branch to avoid these failures

This comment was automatically generated by Dr. CI and updates every 15 minutes.

@sanchitintel sanchitintel added the topic: bug fixes topic category label Nov 11, 2024
@sanchitintel sanchitintel changed the title Simplify B tile loading logic for AMX GEMM micro-kernel for WoQ int8 case Simplify & rectify B tile loading logic for AMX GEMM micro-kernel for WoQ int8 case Nov 11, 2024
@sanchitintel sanchitintel added the topic: not user facing topic category label Nov 11, 2024
@sanchitintel sanchitintel changed the title Simplify & rectify B tile loading logic for AMX GEMM micro-kernel for WoQ int8 case Simplify & rectify dequantized B buffer loading for AMX GEMM micro-kernel for WoQ int8 case Nov 11, 2024
@leslie-fang-intel
Collaborator

BTW: I think horizontal traverse doesn't work well with this cache optimization cc @jgong5 @chunyuan-w

@sanchitintel
Collaborator Author

sanchitintel commented Nov 13, 2024

BTW: I think horizontal traverse doesn't work well with this cache optimization cc @jgong5 @chunyuan-w

Hi, would the horizontal traverse strategy complement the existing AMX GEMM micro-kernel template (by conditionally using it), or would it replace it? Thanks!

@leslie-fang-intel
Collaborator

Hi, would the horizontal traverse strategy complement the existing AMX GEMM micro-kernel template (by conditionally using it), or would it replace it? Thanks!

I think we will use it conditionally.

@ezyang ezyang added the triaged This issue has been looked at by a team member, and triaged and prioritized into an appropriate module label Nov 14, 2024
@sanchitintel sanchitintel requested a review from jgong5 November 18, 2024 09:15
Collaborator

@jgong5 jgong5 left a comment


As we discussed offline, please do not assume the B is contiguous.

@sanchitintel sanchitintel requested a review from jgong5 November 19, 2024 21:54
@pytorchmergebot
Collaborator

@pytorchbot started a rebase job onto refs/remotes/origin/main. Check the current status here

As suggested by @leslie-fang-intel in https://github.com/leslie-fang-intel/pytorch/commit/4c83e4e75138e8fa6e0d58438f75b7718dc8a0cc#diff-139642bd981df977f70f4c18c1c34bd1a85c1d6b9ffa06aaa98426ed83942a31R537
This case cannot be covered by the current UTs, since it hasn't been implemented
This case can't be tested, though, as the N != block_n case has not been implemented.
Don't assume weight-packing at GEMM template level
Its value would also be known at runtime, so it wouldn't affect performance
@pytorchmergebot
Collaborator

Successfully rebased sanchitj/simplify_amx_tile_load onto refs/remotes/origin/main, please pull locally before adding more changes (for example, via git checkout sanchitj/simplify_amx_tile_load && git pull --rebase)

@pytorchmergebot pytorchmergebot force-pushed the sanchitj/simplify_amx_tile_load branch from 3dffe41 to 709adfe Compare November 21, 2024 21:26
@sanchitintel
Collaborator Author

@pytorchbot merge

@pytorch-bot pytorch-bot bot added the ciflow/trunk Trigger trunk jobs on your pull request label Nov 21, 2024
@pytorchmergebot
Collaborator

Merge started

Your change will be merged once all checks pass (ETA 0-4 Hours).

Learn more about merging in the wiki.

Questions? Feedback? Please reach out to the PyTorch DevX Team

Advanced Debugging
Check the merge workflow status here

pobin6 pushed a commit to pobin6/pytorch that referenced this pull request Dec 5, 2024
…rnel for WoQ int8 case (pytorch#140258)

Pull Request resolved: pytorch#140258
Approved by: https://github.com/jgong5, https://github.com/leslie-fang-intel
@github-actions github-actions bot deleted the sanchitj/simplify_amx_tile_load branch December 22, 2024 02:11

Labels

ciflow/inductor
ciflow/trunk (trigger trunk jobs on your pull request)
Merged
module: inductor
open source
topic: bug fixes (topic category)
topic: not user facing (topic category)
triaged (this issue has been looked at by a team member, and triaged and prioritized into an appropriate module)

Projects

None yet

Development

Successfully merging this pull request may close these issues.

[Inductor][CPP] CPP GEMM Template WOQ int8 correctness failure

6 participants