Add Python serialization to Pattern Matcher patterns by eellison · Pull Request #108894 · pytorch/pytorch · GitHub

Conversation


@eellison eellison commented Sep 8, 2023

Stack from ghstack (oldest at bottom):

Adds a Python pretty printer to the pattern matcher that serializes patterns as Python code. Generating our fused attention patterns was taking 4 seconds of compile time, which will only get worse as we add more variants (which I will do in the rest of this stack). To write out patterns, build PyTorch, then run `gen_attention_patterns.py`.

Since there is a line limit for PRs, I'm only including `_sdpa_pattern1` in this first diff.
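For readers unfamiliar with the pattern node types, here is a minimal standalone sketch (mock classes reusing the same names, not the real `torch._inductor.pattern_matcher` API) of what `CallFunction`, `KeywordArg`, and `Ignored` mean as a structural matcher: `Ignored()` matches any argument, and `KeywordArg("query")` matches any argument while recording it under that name.

```python
class Ignored:
    """Matches any single argument without binding it."""
    def match(self, node, bindings):
        return True

class KeywordArg:
    """Matches any argument and records it under a keyword name."""
    def __init__(self, name):
        self.name = name
    def match(self, node, bindings):
        bindings[self.name] = node
        return True

class CallFunction:
    """Matches a call to `target` whose arguments match the sub-patterns."""
    def __init__(self, target, *args, _users=1):
        self.target = target
        self.args = args
        self._users = _users  # how many consumers the matched node must have
    def match(self, node, bindings):
        # In this mockup a graph node is a (target, *arg_nodes) tuple.
        if not isinstance(node, tuple) or node[0] != self.target:
            return False
        if len(node) - 1 != len(self.args):
            return False
        return all(p.match(n, bindings) for p, n in zip(self.args, node[1:]))

# Match a toy div(bmm(query, <anything>), inv_scale) graph:
pattern = CallFunction(
    "aten.div",
    CallFunction("aten.bmm", KeywordArg("query"), Ignored()),
    KeywordArg("inv_scale"),
)
graph = ("aten.div", ("aten.bmm", "q_node", "k_node"), 0.125)
bindings = {}
assert pattern.match(graph, bindings)
assert bindings == {"query": "q_node", "inv_scale": 0.125}
```

The serialized patterns below are built from exactly these three node types, which is why they round-trip cleanly through Python source.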

cc @voznesenskym @penguinwu @EikanWang @jgong5 @Guobing-Chen @XiaobingSuper @zhuhaozhe @blzheng @Xia-Weiwen @wenzhe-nrv @jiayisunx @peterbell10 @ipiszy @ngimel @yf225 @chenyang78 @kadeng @muchulee8 @aakhundov


pytorch-bot bot commented Sep 8, 2023

🔗 Helpful Links

🧪 See artifacts and rendered test results at hud.pytorch.org/pr/108894

Note: Links to docs will display an error until the docs builds have been completed.

✅ No Failures

As of commit 1dd4628 with merge base 518308a:
💚 Looks good so far! There are no failures yet. 💚

This comment was automatically generated by Dr. CI and updates every 15 minutes.

@eellison eellison changed the title Add Pretty Python Print to serialize patterns Add Python serialization to Pattern Matcher patterns Sep 8, 2023
Follow-up diffs in this stack will serialize all inference patterns, then all training patterns.


Example Serialized Pattern:

```python
tmp_0 = CallFunction(
    aten.view.default,
    CallFunction(aten.expand.default, KeywordArg("query"), Ignored()),
    Ignored(),
    _users=2,
)
tmp_1 = CallFunction(
    aten.view.default,
    CallFunction(
        aten.expand.default,
        CallFunction(aten.permute.default, KeywordArg("key"), Ignored()),
        Ignored(),
    ),
    Ignored(),
    _users=2,
)
tmp_2 = CallFunction(
    aten.div.Tensor,
    CallFunction(
        aten.view.default, CallFunction(aten.bmm.default, tmp_0, tmp_1), Ignored()
    ),
    KeywordArg("inv_scale"),
    _users=2,
)
tmp_3 = CallFunction(
    aten.exp.default,
    CallFunction(
        aten.sub.Tensor, tmp_2, CallFunction(aten.amax.default, tmp_2, Ignored(), True)
    ),
    _users=2,
)
tmp_4 = CallFunction(
    aten.div.Tensor,
    tmp_3,
    CallFunction(aten.sum.dim_IntList, tmp_3, Ignored(), True),
    _users=3,
)
tmp_5 = CallFunction(
    aten.view.default,
    CallFunction(aten.expand.default, tmp_4, Ignored()),
    Ignored(),
    _users=2,
)
tmp_6 = CallFunction(
    aten.view.default,
    CallFunction(aten.expand.default, KeywordArg("value"), Ignored()),
    Ignored(),
    _users=2,
)
tmp_7 = CallFunction(aten.view.default, KeywordArg("tangents_1"), Ignored(), _users=2)
tmp_8 = CallFunction(
    aten.mul.Tensor,
    CallFunction(
        aten.view.default,
        CallFunction(
            aten.bmm.default,
            tmp_7,
            CallFunction(aten.permute.default, tmp_6, Ignored()),
        ),
        Ignored(),
    ),
    tmp_4,
    _users=2,
)
tmp_9 = CallFunction(
    aten.view.default,
    CallFunction(
        aten.div.Tensor,
        CallFunction(
            aten.sub.Tensor,
            tmp_8,
            CallFunction(
                aten.mul.Tensor,
                tmp_4,
                CallFunction(aten.sum.dim_IntList, tmp_8, Ignored(), True),
            ),
        ),
        KeywordArg("inv_scale"),
    ),
    Ignored(),
    _users=2,
)
```
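A rough standalone sketch (hypothetical helper names, not the actual implementation) of how a pretty printer can emit the `tmp_N = CallFunction(...)` form above: count how often each subtree occurs, hoist every shared subtree into a `tmp_N` assignment, and print single-use subtrees inline.

```python
from collections import Counter

def count_subtrees(node, counts):
    """Count occurrences of each CallFunction subtree (tuples in this mockup)."""
    if isinstance(node, tuple):
        counts[node] += 1
        for arg in node[1:]:
            count_subtrees(arg, counts)

def emit(node, counts, memo, lines):
    """Return the printed form of `node`, hoisting shared subtrees to tmp_N."""
    if not isinstance(node, tuple):
        return node  # leaf, e.g. 'KeywordArg("q")' or 'Ignored()'
    if node in memo:
        return memo[node]  # already hoisted: reuse its tmp name
    args = ", ".join(emit(a, counts, memo, lines) for a in node[1:])
    expr = f"CallFunction({node[0]}, {args})"
    if counts[node] > 1:  # shared subtree: emit an assignment once
        name = f"tmp_{len(lines)}"
        lines.append(f"{name} = {expr}")
        memo[node] = name
        return name
    return expr

# Toy pattern: exp(sub(x, amax(x))) where x = div(...) is used twice.
x = ("aten.div.Tensor", 'KeywordArg("q")', 'KeywordArg("inv_scale")')
root = ("aten.exp.default", ("aten.sub.Tensor", x, ("aten.amax.default", x)))
counts = Counter()
count_subtrees(root, counts)
memo, lines = {}, []
lines.append(f"result = {emit(root, counts, memo, lines)}")
print("\n".join(lines))
```

Because `x` occurs twice, it is printed once as `tmp_0` and referenced twice, which is both shorter and cheaper to re-instantiate than fully expanded nesting.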


eellison added a commit that referenced this pull request Sep 8, 2023
ghstack-source-id: 3b70e17
Pull Request resolved: #108894
jansel previously requested changes Sep 8, 2023


Do we need a test to ensure the pattern file is up to date?
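One way such a staleness check could look (a sketch with assumed file and helper names, not actual PyTorch test code): regenerate the serialized patterns in memory and compare them against the checked-in file, so CI fails whenever a pattern changes without re-running `gen_attention_patterns.py`.

```python
import pathlib

def assert_patterns_up_to_date(generated: str, path: str) -> None:
    """Fail when the checked-in serialized patterns differ from a fresh
    in-memory regeneration. `generated` would come from the same code path
    that gen_attention_patterns.py uses to write the file."""
    on_disk = pathlib.Path(path).read_text()
    if generated != on_disk:
        raise AssertionError(
            f"{path} is stale; re-run gen_attention_patterns.py to refresh it"
        )
```

A test built on this helper keeps the generated file honest without requiring the generator to run at import time.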

@eellison eellison dismissed jansel’s stale review September 11, 2023 19:00

Addressed comments

@pytorchmergebot

Merge started

Your change will be merged immediately since you used the force (-f) flag, bypassing any CI checks (ETA: 1-5 minutes). Please use -f only as a last resort, and instead consider -i/--ignore-current to continue the merge while ignoring current failures. This will allow currently pending tests to finish and report signal before the merge.


pytorchmergebot pushed a commit that referenced this pull request Sep 19, 2023
Serializes the remaining traced patterns.

Pull Request resolved: #108917
Approved by: https://github.com/davidberard98
ghstack dependencies: #108894
pytorchmergebot pushed a commit that referenced this pull request Sep 19, 2023
`aten.softmax` will generate a different decomposition for fp16/bf16 than for fp32 because, when invoked in lower precision, it upcasts the inputs to fp32 and downcasts afterward. This has been causing us to miss bf16 patterns. For example, Camembert improves 20% with this PR (as, I'm sure, do many other models).

Pull Request resolved: #109142
Approved by: https://github.com/yanboliang
ghstack dependencies: #108894, #108917
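The upcast/downcast effect described in that commit can be illustrated with a pure-Python sketch (hypothetical op names, not the real decomposition code): in reduced precision the decomposed graph gains extra cast nodes around the same `amax`/`sub`/`exp`/`sum`/`div` core, so a pattern traced only in fp32 never matches the bf16 graph.

```python
def decompose_softmax(dtype: str) -> list:
    """Toy model of softmax decomposition: low-precision inputs are upcast
    to fp32 before the math and downcast after, changing the traced graph."""
    core = ["amax", "sub", "exp", "sum", "div"]
    if dtype in ("float16", "bfloat16"):
        # upcast before the math, downcast after
        return ["to_fp32", *core, f"to_{dtype}"]
    return core

assert decompose_softmax("float32") == ["amax", "sub", "exp", "sum", "div"]
assert decompose_softmax("bfloat16") != decompose_softmax("float32")
```

This is why the pattern files must be generated per dtype: the bf16 graph simply contains nodes the fp32 graph does not.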
@eellison

@pytorchbot revert -m "land race" -c landrace

@eellison eellison reopened this Sep 19, 2023
@pytorchmergebot

@pytorchbot successfully started a revert job. Check the current status here.

@pytorchmergebot

@eellison your PR has been successfully reverted.

pytorchmergebot added a commit that referenced this pull request Sep 19, 2023
@eellison

@pytorchbot merge

@pytorchmergebot

Merge started

Your change will be merged once all checks pass (ETA 0-4 Hours).


pytorchmergebot pushed a commit that referenced this pull request Sep 20, 2023
Serializes the remaining traced patterns.

Pull Request resolved: #108917
Approved by: https://github.com/davidberard98
ghstack dependencies: #109663, #108894
pytorchmergebot pushed a commit that referenced this pull request Sep 20, 2023
`aten.softmax` will generate a different decomposition for fp16/bf16 than for fp32 because, when invoked in lower precision, it upcasts the inputs to fp32 and downcasts afterward. This has been causing us to miss bf16 patterns. For example, Camembert improves 20% with this PR (as, I'm sure, do many other models).

Pull Request resolved: #109142
Approved by: https://github.com/yanboliang
ghstack dependencies: #109663, #108894, #108917
pytorchmergebot pushed a commit that referenced this pull request Sep 20, 2023
Adds a 3d pattern that improves perf of HF Whisper from 1.3 -> 4.1. We could be matching more generally on 3d, but I'll leave that for another PR.

Thanks to @drisspg for helping me write the pattern.

Pull Request resolved: #109156
Approved by: https://github.com/yanboliang
ghstack dependencies: #109663, #108894, #108917, #109142
pytorchmergebot pushed a commit that referenced this pull request Sep 20, 2023
The pretty printer is faster and more concise because it memoizes objects.

Pull Request resolved: #109066
Approved by: https://github.com/yanboliang
ghstack dependencies: #109663, #108894, #108917, #109142, #109156
@facebook-github-bot facebook-github-bot deleted the gh/eellison/529/head branch September 23, 2023 14:22
