[Inductor] Improve RoPE #161420
Conversation
✅ No failures as of commit c121f59 with merge base 9491d28. See artifacts and rendered test results at hud.pytorch.org/pr/161420.
```
# Threshold to decide if a kernel has small memory access, in bytes.
# Default value is 16 MB, which is arbitrarily selected.
small_memory_access_threshold: int = 16777216
```
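For context on what this threshold means in practice, here is a rough sketch (not the actual Inductor logic) of how a byte-count check against this default would classify the RoPE inputs from this PR; the helper names and the single-pass read/write assumption are illustrative.

```python
# Rough sketch: classify an op as "small memory access" by comparing its total
# byte traffic against the 16 MiB default from the diff above.
import torch

SMALL_MEMORY_ACCESS_THRESHOLD = 16 * 1024 * 1024  # 16777216 bytes

def total_access_bytes(*tensors: torch.Tensor) -> int:
    # Sum of bytes touched, assuming each tensor is read or written once.
    return sum(t.numel() * t.element_size() for t in tensors)

def has_small_memory_access(*tensors: torch.Tensor) -> bool:
    return total_access_bytes(*tensors) < SMALL_MEMORY_ACCESS_THRESHOLD

# Llama3-style RoPE inputs for a single decode step (B=1, S=1):
q = torch.empty(1, 32, 1, 128, dtype=torch.bfloat16)
k = torch.empty(1, 8, 1, 128, dtype=torch.bfloat16)
print(has_small_memory_access(q, k))  # True: ~10 KB, far below 16 MiB
```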
Could we run some rough benchmarks on this threshold for RoPE if you haven't? It would be good to know in general.
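As a starting point, one rough way to probe the threshold is to time a compiled elementwise op at memory footprints on either side of the 16 MiB default; the sketch below uses torch.utils.benchmark with a generic stand-in op, not the actual Inductor fusion path.

```python
# Rough benchmark sketch: time a compiled elementwise op with memory footprints
# below and above the 16 MiB default. A stand-in measurement only.
import torch
from torch.utils import benchmark

def f(x):
    return (x * 1.0001 + 1.0).relu()

compiled = torch.compile(f)

for numel in (1 << 20, 1 << 24):  # ~2 MiB and ~32 MiB of bf16 input
    x = torch.randn(numel, device="cuda", dtype=torch.bfloat16)
    compiled(x)  # warm up and trigger compilation
    t = benchmark.Timer(stmt="compiled(x)", globals={"compiled": compiled, "x": x})
    print(f"{x.numel() * x.element_size() / 2**20:.1f} MiB: {t.timeit(100).median * 1e6:.1f} us")
```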
Sorry, I still feel a bit skeptical here; can you clarify a bit?
Stamping it since it's off by default. But I'm not fully convinced that gaining 32 us (1 us per layer) for the whole Llama3-8B inference is worth adding this complexity to the compiler. Maybe try to find out whether the optimization can be applied more broadly.
Also make sure to address Elias's comment above before turning this on by default.
@pytorchbot merge
Merge started. Your change will be merged once all checks pass (ETA 0-4 hours).
@shunting314 we have found that not to be the case in vLLM; the extra kernel call is expensive, even with CUDA graphs enabled.
Can you elaborate a bit? Do you mean in cases where there are a lot of small kernels, or in general? Benchmarking seems to show the cost is around 1 us.
@shunting314 yeah, if you look at vllm-project/vllm#22293, you can see that currently the generated sequence of 3 Triton kernels causes a significant overhead.

This PR fuses ROPE from 2 kernels into 1 kernel.
Shape:
Hq=32, Hkv=8, D=128following Llama3 setting.cc @voznesenskym @penguinwu @EikanWang @jgong5 @Guobing-Chen @XiaobingSuper @zhuhaozhe @blzheng @wenzhe-nrv @jiayisunx @ipiszy @chenyang78 @kadeng @muchulee8 @amjames @chauhang @aakhundov @coconutruben
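For readers unfamiliar with the pattern being fused, below is a minimal eager-mode rotate-half RoPE sketch using the shapes from the description; the function names, the rope base of 500000, and the rotate-half formulation are illustrative assumptions, not the exact graph this PR optimizes.

```python
# Minimal rotate-half RoPE sketch in eager PyTorch with Hq=32, Hkv=8, D=128.
import torch

def rotate_half(x: torch.Tensor) -> torch.Tensor:
    x1, x2 = x.chunk(2, dim=-1)
    return torch.cat((-x2, x1), dim=-1)

def apply_rope(q, k, cos, sin):
    # q: [B, Hq, S, D], k: [B, Hkv, S, D], cos/sin: [1, 1, S, D].
    # Rotating q and k separately is what tends to produce two elementwise
    # kernels before this PR's fusion.
    q_out = q * cos + rotate_half(q) * sin
    k_out = k * cos + rotate_half(k) * sin
    return q_out, k_out

B, Hq, Hkv, S, D = 1, 32, 8, 4096, 128
q = torch.randn(B, Hq, S, D, device="cuda", dtype=torch.bfloat16)
k = torch.randn(B, Hkv, S, D, device="cuda", dtype=torch.bfloat16)
pos = torch.arange(S, device="cuda", dtype=torch.float32)
inv_freq = 1.0 / (500000.0 ** (torch.arange(0, D, 2, device="cuda", dtype=torch.float32) / D))
freqs = torch.outer(pos, inv_freq)       # [S, D/2]
emb = torch.cat((freqs, freqs), dim=-1)  # [S, D]
cos = emb.cos()[None, None].to(q.dtype)
sin = emb.sin()[None, None].to(q.dtype)

compiled = torch.compile(apply_rope)
q_out, k_out = compiled(q, k, cos, sin)
```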