-
Notifications
You must be signed in to change notification settings - Fork 25.7k
Add Python serialization to Pattern Matcher patterns #108894
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Conversation
[ghstack-poisoned]
🔗 Helpful Links🧪 See artifacts and rendered test results at hud.pytorch.org/pr/108894
Note: Links to docs will display an error until the docs builds have been completed. ✅ No FailuresAs of commit 1dd4628 with merge base 518308a ( This comment was automatically generated by Dr. CI and updates every 15 minutes. |
cc voznesenskym penguinwu EikanWang jgong5 Guobing-Chen XiaobingSuper zhuhaozhe blzheng Xia-Weiwen wenzhe-nrv jiayisunx peterbell10 ipiszy ngimel yf225 chenyang78 kadeng muchulee8 aakhundov [ghstack-poisoned]
Adds a Python Pretty Printer to the pattern matcher that serializes patterns as python. Generating our fuse attention patterns was taking 4 seconds of compile time, which will only get worse as we add more variants (which I will do in the rest of this stack). To write out patterns, build pytorch, then run `gen_attention_patterns.py`.
Since there is a line limit for PRs i'm only including the _sdpa_pattern1 in this first diff. Then I will include all inference patterns, then all training patterns.
Example Serialized Pattern:
```
tmp_0 = CallFunction(
aten.view.default,
CallFunction(aten.expand.default, KeywordArg("query"), Ignored()),
Ignored(),
_users=2,
)
tmp_1 = CallFunction(
aten.view.default,
CallFunction(
aten.expand.default,
CallFunction(aten.permute.default, KeywordArg("key"), Ignored()),
Ignored(),
),
Ignored(),
_users=2,
)
tmp_2 = CallFunction(
aten.div.Tensor,
CallFunction(
aten.view.default, CallFunction(aten.bmm.default, tmp_0, tmp_1), Ignored()
),
KeywordArg("inv_scale"),
_users=2,
)
tmp_3 = CallFunction(
aten.exp.default,
CallFunction(
aten.sub.Tensor, tmp_2, CallFunction(aten.amax.default, tmp_2, Ignored(), True)
),
_users=2,
)
tmp_4 = CallFunction(
aten.div.Tensor,
tmp_3,
CallFunction(aten.sum.dim_IntList, tmp_3, Ignored(), True),
_users=3,
)
tmp_5 = CallFunction(
aten.view.default,
CallFunction(aten.expand.default, tmp_4, Ignored()),
Ignored(),
_users=2,
)
tmp_6 = CallFunction(
aten.view.default,
CallFunction(aten.expand.default, KeywordArg("value"), Ignored()),
Ignored(),
_users=2,
)
tmp_7 = CallFunction(aten.view.default, KeywordArg("tangents_1"), Ignored(), _users=2)
tmp_8 = CallFunction(
aten.mul.Tensor,
CallFunction(
aten.view.default,
CallFunction(
aten.bmm.default,
tmp_7,
CallFunction(aten.permute.default, tmp_6, Ignored()),
),
Ignored(),
),
tmp_4,
_users=2,
)
tmp_9 = CallFunction(
aten.view.default,
CallFunction(
aten.div.Tensor,
CallFunction(
aten.sub.Tensor,
tmp_8,
CallFunction(
aten.mul.Tensor,
tmp_4,
CallFunction(aten.sum.dim_IntList, tmp_8, Ignored(), True),
),
),
KeywordArg("inv_scale"),
),
Ignored(),
_users=2,
)
```
cc voznesenskym penguinwu EikanWang jgong5 Guobing-Chen XiaobingSuper zhuhaozhe blzheng Xia-Weiwen wenzhe-nrv jiayisunx peterbell10 ipiszy ngimel yf225 chenyang78 kadeng muchulee8 aakhundov
[ghstack-poisoned]
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Do we need a test to ensure the pattern file is up to date?
torch/_inductor/fx_passes/serialized_attention_patterns/_sfdp_pattern_1.py
Outdated
Show resolved
Hide resolved
Adds a Python Pretty Printer to the pattern matcher that serializes patterns as python. Generating our fuse attention patterns was taking 4 seconds of compile time, which will only get worse as we add more variants (which I will do in the rest of this stack). To write out patterns, build pytorch, then run `gen_attention_patterns.py`. Since there is a line limit for PRs i'm only including the _sdpa_pattern1 in this first diff. cc voznesenskym penguinwu EikanWang jgong5 Guobing-Chen XiaobingSuper zhuhaozhe blzheng Xia-Weiwen wenzhe-nrv jiayisunx peterbell10 ipiszy ngimel yf225 chenyang78 kadeng muchulee8 aakhundov [ghstack-poisoned]
Adds a Python Pretty Printer to the pattern matcher that serializes patterns as python. Generating our fuse attention patterns was taking 4 seconds of compile time, which will only get worse as we add more variants (which I will do in the rest of this stack). To write out patterns, build pytorch, then run `gen_attention_patterns.py`. Since there is a line limit for PRs i'm only including the _sdpa_pattern1 in this first diff. cc voznesenskym penguinwu EikanWang jgong5 Guobing-Chen XiaobingSuper zhuhaozhe blzheng Xia-Weiwen wenzhe-nrv jiayisunx peterbell10 ipiszy ngimel yf225 chenyang78 kadeng muchulee8 aakhundov [ghstack-poisoned]
Adds a Python Pretty Printer to the pattern matcher that serializes patterns as python. Generating our fuse attention patterns was taking 4 seconds of compile time, which will only get worse as we add more variants (which I will do in the rest of this stack). To write out patterns, build pytorch, then run `gen_attention_patterns.py`. Since there is a line limit for PRs i'm only including the _sdpa_pattern1 in this first diff. cc voznesenskym penguinwu EikanWang jgong5 Guobing-Chen XiaobingSuper zhuhaozhe blzheng Xia-Weiwen wenzhe-nrv jiayisunx peterbell10 ipiszy ngimel yf225 chenyang78 kadeng muchulee8 aakhundov [ghstack-poisoned]
Adds a Python Pretty Printer to the pattern matcher that serializes patterns as python. Generating our fuse attention patterns was taking 4 seconds of compile time, which will only get worse as we add more variants (which I will do in the rest of this stack). To write out patterns, build pytorch, then run `gen_attention_patterns.py`. Since there is a line limit for PRs i'm only including the _sdpa_pattern1 in this first diff. cc voznesenskym penguinwu EikanWang jgong5 Guobing-Chen XiaobingSuper zhuhaozhe blzheng Xia-Weiwen wenzhe-nrv jiayisunx peterbell10 ipiszy ngimel yf225 chenyang78 kadeng muchulee8 aakhundov [ghstack-poisoned]
Adds a Python Pretty Printer to the pattern matcher that serializes patterns as python. Generating our fuse attention patterns was taking 4 seconds of compile time, which will only get worse as we add more variants (which I will do in the rest of this stack). To write out patterns, build pytorch, then run `gen_attention_patterns.py`. Since there is a line limit for PRs i'm only including the _sdpa_pattern1 in this first diff. cc voznesenskym penguinwu EikanWang jgong5 Guobing-Chen XiaobingSuper zhuhaozhe blzheng Xia-Weiwen wenzhe-nrv jiayisunx peterbell10 ipiszy ngimel yf225 chenyang78 kadeng muchulee8 aakhundov [ghstack-poisoned]
Adds a Python Pretty Printer to the pattern matcher that serializes patterns as python. Generating our fuse attention patterns was taking 4 seconds of compile time, which will only get worse as we add more variants (which I will do in the rest of this stack). To write out patterns, build pytorch, then run `gen_attention_patterns.py`. Since there is a line limit for PRs i'm only including the _sdpa_pattern1 in this first diff. cc voznesenskym penguinwu EikanWang jgong5 Guobing-Chen XiaobingSuper zhuhaozhe blzheng Xia-Weiwen wenzhe-nrv jiayisunx peterbell10 ipiszy ngimel yf225 chenyang78 kadeng muchulee8 aakhundov [ghstack-poisoned]
Merge startedYour change will be merged immediately since you used the force (-f) flag, bypassing any CI checks (ETA: 1-5 minutes). Please use Learn more about merging in the wiki. Questions? Feedback? Please reach out to the PyTorch DevX Team |
Serializes the remaining traced patterns. Pull Request resolved: #108917 Approved by: https://github.com/davidberard98 ghstack dependencies: #108894
aten.softmax will generate a different decomposition for fp16/bf16 and fp32 because when invoked in lower precision it will upcast the inputs to fp32 and then downcast after. This has been causing us to miss bf16 patterns. For example, Camembert improves 20% with this PR (as do I'm sure many other models). Pull Request resolved: #109142 Approved by: https://github.com/yanboliang ghstack dependencies: #108894, #108917
|
@pytorchbot revert -m "land race" -c landrace |
|
@pytorchbot successfully started a revert job. Check the current status here. |
|
@eellison your PR has been successfully reverted. |
This reverts commit 7db175b. Reverted #108894 on behalf of https://github.com/eellison due to land race ([comment](#108894 (comment)))
Adds a Python Pretty Printer to the pattern matcher that serializes patterns as python. Generating our fuse attention patterns was taking 4 seconds of compile time, which will only get worse as we add more variants (which I will do in the rest of this stack). To write out patterns, build pytorch, then run `gen_attention_patterns.py`. Since there is a line limit for PRs i'm only including the _sdpa_pattern1 in this first diff. cc voznesenskym penguinwu EikanWang jgong5 Guobing-Chen XiaobingSuper zhuhaozhe blzheng Xia-Weiwen wenzhe-nrv jiayisunx peterbell10 ipiszy ngimel yf225 chenyang78 kadeng muchulee8 aakhundov [ghstack-poisoned]
Adds a Python Pretty Printer to the pattern matcher that serializes patterns as python. Generating our fuse attention patterns was taking 4 seconds of compile time, which will only get worse as we add more variants (which I will do in the rest of this stack). To write out patterns, build pytorch, then run `gen_attention_patterns.py`. Since there is a line limit for PRs i'm only including the _sdpa_pattern1 in this first diff. cc voznesenskym penguinwu EikanWang jgong5 Guobing-Chen XiaobingSuper zhuhaozhe blzheng Xia-Weiwen wenzhe-nrv jiayisunx peterbell10 ipiszy ngimel yf225 chenyang78 kadeng muchulee8 aakhundov [ghstack-poisoned]
Adds a Python Pretty Printer to the pattern matcher that serializes patterns as python. Generating our fuse attention patterns was taking 4 seconds of compile time, which will only get worse as we add more variants (which I will do in the rest of this stack). To write out patterns, build pytorch, then run `gen_attention_patterns.py`. Since there is a line limit for PRs i'm only including the _sdpa_pattern1 in this first diff. cc voznesenskym penguinwu EikanWang jgong5 Guobing-Chen XiaobingSuper zhuhaozhe blzheng Xia-Weiwen wenzhe-nrv jiayisunx peterbell10 ipiszy ngimel yf225 chenyang78 kadeng muchulee8 aakhundov [ghstack-poisoned]
|
@pytorchbot merge |
Merge startedYour change will be merged once all checks pass (ETA 0-4 Hours). Learn more about merging in the wiki. Questions? Feedback? Please reach out to the PyTorch DevX Team |
Serializes the remaining traced patterns. Pull Request resolved: #108917 Approved by: https://github.com/davidberard98 ghstack dependencies: #109663, #108894
aten.softmax will generate a different decomposition for fp16/bf16 and fp32 because when invoked in lower precision it will upcast the inputs to fp32 and then downcast after. This has been causing us to miss bf16 patterns. For example, Camembert improves 20% with this PR (as do I'm sure many other models). Pull Request resolved: #109142 Approved by: https://github.com/yanboliang ghstack dependencies: #109663, #108894, #108917
Adds a 3d pattern that improves perf of HF Whisper from 1.3 -> 4.1. We could be matching more generally on 3d, but i'll leave that for another pr. Thanks to @drisspg for helping me write the pattern. Pull Request resolved: #109156 Approved by: https://github.com/yanboliang ghstack dependencies: #109663, #108894, #108917, #109142
Stack from ghstack (oldest at bottom):
Adds a Python Pretty Printer to the pattern matcher that serializes patterns as python. Generating our fuse attention patterns was taking 4 seconds of compile time, which will only get worse as we add more variants (which I will do in the rest of this stack). To write out patterns, build pytorch, then run
gen_attention_patterns.py.Since there is a line limit for PRs i'm only including the _sdpa_pattern1 in this first diff.
cc @voznesenskym @penguinwu @EikanWang @jgong5 @Guobing-Chen @XiaobingSuper @zhuhaozhe @blzheng @Xia-Weiwen @wenzhe-nrv @jiayisunx @peterbell10 @ipiszy @ngimel @yf225 @chenyang78 @kadeng @muchulee8 @aakhundov