[inductor] Make sure unfuse_addmm and addmm patterns don't overlap by peterbell10 · Pull Request #110235 · pytorch/pytorch · GitHub

Conversation

@peterbell10
Collaborator

@peterbell10 commented Sep 28, 2023

Stack from ghstack (oldest at bottom):

Inductor has two opposing patterns:

```
addmm -> add + mm
add + mm -> addmm
```

This uses the `extra_check` hook to disable the addmm fusion pattern whenever the heuristic to unfuse add is met, so the two rewrites never overlap.
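
For context, a minimal sketch of the mechanism this relies on, assuming inductor's `pattern_matcher` API (`register_graph_pattern`, `CallFunction`, `KeywordArg`, `PatternMatcherPass`); the pass object, handler body and pattern shown here are illustrative, not the exact code in this PR:

```
# Illustrative sketch only: an extra_check predicate gates the fusion pattern
# so it never fires when the opposing "unfuse add" heuristic would apply.
import torch
from torch._inductor.pattern_matcher import (
    CallFunction,
    KeywordArg,
    PatternMatcherPass,
    register_graph_pattern,
)

aten = torch.ops.aten
example_pass = PatternMatcherPass()  # hypothetical pass used for this sketch


def should_prefer_unfused_addmm(match) -> bool:
    # Placeholder for the real heuristic that decides to keep add + mm split.
    raise NotImplementedError


@register_graph_pattern(
    CallFunction(
        aten.add,
        CallFunction(aten.mm, KeywordArg("mat1"), KeywordArg("mat2")),
        KeywordArg("inp"),
    ),
    pass_dict=example_pass,
    # Skip the addmm fusion whenever the unfuse heuristic would fire, so the
    # two opposing rewrites cannot ping-pong on the same nodes.
    extra_check=lambda match: not should_prefer_unfused_addmm(match),
)
def addmm_fusion(match, mat1, mat2, inp):
    ...  # rewrite add(mm(mat1, mat2), inp) into addmm(inp, mat1, mat2)
```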

cc @voznesenskym @penguinwu @EikanWang @jgong5 @Guobing-Chen @XiaobingSuper @zhuhaozhe @blzheng @Xia-Weiwen @wenzhe-nrv @jiayisunx @ipiszy @yf225 @chenyang78 @kadeng @muchulee8 @aakhundov @ColinPeppler

@pytorch-bot

pytorch-bot bot commented Sep 28, 2023

🔗 Helpful Links

🧪 See artifacts and rendered test results at hud.pytorch.org/pr/110235

Note: Links to docs will display an error until the docs builds have been completed.

✅ No Failures

As of commit fab9c1e with merge base 419ec3b:
💚 Looks good so far! There are no failures yet. 💚

This comment was automatically generated by Dr. CI and updates every 15 minutes.

peterbell10 added a commit that referenced this pull request Sep 29, 2023
ghstack-source-id: 3f78107
Pull Request resolved: #110235
@peterbell10 marked this pull request as ready for review September 29, 2023 12:39
```
torch.testing.assert_close(a2, e2)
self.assertEqual(counters["inductor"]["pattern_matcher_count"], 2)
self.assertEqual(counters["inductor"]["pattern_matcher_nodes"], 4)
count, nodes = (2, 4) if should_fuse else (0, 0)
```
Collaborator Author

Note that these cases weren't actually fused previously; the pattern just replaced them with a lowering that did add + mm.

Collaborator
@lezcano left a comment

Fair enough. I think the code has a preexisting issue that we should fix, though.

Collaborator
@lezcano left a comment

Even better

@peterbell10
Collaborator Author

@pytorchbot merge

@pytorch-bot added the ciflow/trunk (Trigger trunk jobs on your pull request) label Sep 29, 2023
@pytorchmergebot
Collaborator

Merge failed

Reason: This PR needs a `release notes:` label.
If your changes are user facing and intended to be a part of the release notes, please use a label starting with `release notes:`.

If not, please add the `topic: not user facing` label.

To add a label, you can comment to pytorchbot, for example
@pytorchbot label "topic: not user facing"

For more information, see
https://github.com/pytorch/pytorch/wiki/PyTorch-AutoLabel-Bot#why-categorize-for-release-notes-and-how-does-it-work.

Details for Dev Infra team: raised by workflow job.

@peterbell10
Collaborator Author

@pytorchbot merge

@pytorchmergebot
Collaborator

Merge started

Your change will be merged once all checks pass (ETA 0-4 Hours).

Learn more about merging in the wiki.

Questions? Feedback? Please reach out to the PyTorch DevX Team

Advanced Debugging: check the merge workflow status here.

Contributor
@eellison left a comment

Looks good, just one comment about checking for the input being a tensor.

Comment on lines -462 to -470
```
def addmm(match, mat1, mat2, inp):
    if isinstance(inp, ir.TensorBox):
        inp_shape = inp.get_size()
        matched = len(inp_shape) <= 2
        mm_shape = shape_of_mm(mat1, mat2)
        for i, m in zip(inp_shape, mm_shape):
            matched &= i == 1 or i == m
    else:  # inp is a Number
        matched = False
```
Contributor

Nice to move this away from the graph lowering pattern; this was overdue.

```
    return not should_prefer_unfused_addmm(match)


@register_graph_pattern(
```
Contributor

cc @yanboliang @jansel: we should have some sort of commutativity concept that would avoid this duplication.
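
The duplication in question can be illustrated roughly as follows (hypothetical pattern objects, not this PR's exact registrations): a commutative op like `aten.add` needs one pattern per operand order.

```
# Hypothetical sketch of why commutativity forces duplicated patterns:
# add(inp, mm(...)) and add(mm(...), inp) are distinct trees to the matcher,
# so the same logical rewrite must currently be registered twice.
import torch
from torch._inductor.pattern_matcher import CallFunction, KeywordArg

aten = torch.ops.aten

mm_node = CallFunction(aten.mm, KeywordArg("mat1"), KeywordArg("mat2"))

addmm_pattern_lhs = CallFunction(aten.add, KeywordArg("inp"), mm_node)
addmm_pattern_rhs = CallFunction(aten.add, mm_node, KeywordArg("inp"))
```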

Comment on lines 787 to 788
```
    if not isinstance(inp, torch.fx.Node):
        return False  # Input is a number
```
Contributor

I've made this check before, which was fixed by #108160; you can have an fx.Node input which is a SymInt/SymFloat.
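
For illustration, a hedged sketch of the kind of check being suggested, assuming the traced node carries its fake value in `meta["val"]`; the helper name is hypothetical:

```
# Hypothetical helper: an fx.Node can still wrap a SymInt/SymFloat, so the
# isinstance(inp, torch.fx.Node) test alone does not prove inp is a tensor.
import torch
from torch.fx import Node


def input_is_tensor(inp) -> bool:
    if not isinstance(inp, Node):
        return False  # plain Python number
    val = inp.meta.get("val")  # FakeTensor, SymInt, or SymFloat when traced
    return isinstance(val, torch.Tensor)
```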

peterbell10 added a commit that referenced this pull request Sep 29, 2023
ghstack-source-id: f1ea08f
Pull Request resolved: #110235
@pytorchmergebot
Collaborator

Merge failed

Reason: New commits were pushed while merging. Please rerun the merge command.

Details for Dev Infra team: raised by workflow job.

@peterbell10
Collaborator Author

@pytorchbot merge

@pytorchmergebot
Collaborator

Merge started

Your change will be merged once all checks pass (ETA 0-4 Hours).

Learn more about merging in the wiki.

Questions? Feedback? Please reach out to the PyTorch DevX Team

Advanced Debugging: check the merge workflow status here.

@facebook-github-bot deleted the gh/peterbell10/625/head branch October 3, 2023 14:23