[Inductor] Improve reinplace_scatters pass #112801
[ghstack-poisoned]
🔗 Helpful Links
🧪 See artifacts and rendered test results at hud.pytorch.org/pr/112801
Note: Links to docs will display an error until the docs builds have been completed.
✅ You can merge normally! (5 Unrelated Failures)
As of commit 1acb9af with merge base 7715b47:
FLAKY - The following job failed but was likely due to flakiness present on trunk:
BROKEN TRUNK - The following jobs failed but were present on the merge base: 👉 Rebase onto the `viable/strict` branch to avoid these failures
This comment was automatically generated by Dr. CI and updates every 15 minutes.
cc voznesenskym penguinwu EikanWang jgong5 Guobing-Chen XiaobingSuper zhuhaozhe blzheng wenzhe-nrv jiayisunx peterbell10 ipiszy yf225 chenyang78 kadeng muchulee8 aakhundov ColinPeppler [ghstack-poisoned]
if len(mutated_arg.users) > 1:  # Arg used somewhere else
    return False
return True
return not any_use_of_views_after_node(node, shared_view_nodes, None)
nit: pass in copy_node by kwarg here.
shares storage with no other nodes.
Reinplaces scatter operations.
If there are no uses of a view of the mutated arg after the current node,
it is possible to inplace the op.
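The condition in that docstring can be sketched in plain Python. This is only an illustration under stated assumptions, not the PR's implementation: nodes are represented as strings, and `node_order` and `users` are hypothetical dictionaries standing in for the graph's topological order and each node's user list. `copy_node` is keyword-only here, in the spirit of the review nit about passing it by kwarg.

```python
from typing import Dict, Iterable, List, Optional


def any_use_of_views_after_node(
    node: str,
    shared_view_nodes: Iterable[str],
    *,
    copy_node: Optional[str],
    node_order: Dict[str, int],      # hypothetical: node name -> position in the graph
    users: Dict[str, List[str]],     # hypothetical: node name -> nodes that read it
) -> bool:
    """Return True if any node sharing storage is read after `node`."""
    node_loc = node_order[node]
    for view in shared_view_nodes:
        for user in users.get(view, []):
            # Uses at or before the mutating node cannot observe the mutation.
            if node_order[user] <= node_loc:
                continue
            # The epilogue copy is the one later use that may be tolerated.
            if copy_node is not None and user == copy_node:
                continue
            return True
    return False


# Case 1: "out" reads x after the scatter, so inplacing would be unsafe.
order = {"x": 0, "view": 1, "scatter": 2, "out": 3}
blocked = any_use_of_views_after_node(
    "scatter", ["x", "view"], copy_node=None,
    node_order=order, users={"x": ["view", "out"], "view": ["scatter"]},
)

# Case 2: nothing reads x or its view after the scatter.
safe = not any_use_of_views_after_node(
    "scatter", ["x", "view"], copy_node=None,
    node_order=order, users={"x": ["view"], "view": ["scatter"]},
)

# Case 3: the only later use is the epilogue copy, which is skipped.
order_copy = {"x": 0, "view": 1, "scatter": 2, "copy_": 3}
users_copy = {"x": ["view", "copy_"], "view": ["scatter"]}
safe_with_copy = not any_use_of_views_after_node(
    "scatter", ["x", "view"], copy_node="copy_",
    node_order=order_copy, users=users_copy,
)
```

The informal argument is that a mutation is only observable through a later read of the same storage, so if no node aliasing the mutated argument is read after the current node, the functional and in-place versions cannot be told apart.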
It's worth adding a short description here of why we think this algorithm is correct.
if any_use_of_views_after_node(node, shared_view_nodes, copy_node):
    return False

graph.erase_node(copy_node)
side nit: I think it's a bit weird for the node erasure to be done here. I would think that can_inplace is just a check, but it's also responsible here for doing the graph rewrite.
We could move it, but then we would need a ternary (three-valued) return value, which is slightly more confusing.
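For readers weighing the two options, here is a rough sketch of what the three-valued alternative might look like. Everything in it is hypothetical (`InplaceDecision`, `decide_inplace`, `apply_decision`, and the use of strings for nodes); it only illustrates separating the check from the graph rewrite and is not code from the PR.

```python
from enum import Enum, auto
from typing import List, Optional


class InplaceDecision(Enum):
    """Hypothetical three-valued result of a pure check."""
    NO = auto()               # cannot inplace
    YES = auto()              # can inplace, nothing else to do
    YES_ERASE_COPY = auto()   # can inplace, and the epilogue copy must be erased


def decide_inplace(views_used_later: bool, copy_node: Optional[str]) -> InplaceDecision:
    # Pure check: no graph mutation happens here.
    if views_used_later:
        return InplaceDecision.NO
    if copy_node is not None:
        return InplaceDecision.YES_ERASE_COPY
    return InplaceDecision.YES


def apply_decision(
    decision: InplaceDecision, copy_node: Optional[str], erased: List[str]
) -> bool:
    # The rewrite lives with the caller; `erased` stands in for graph.erase_node.
    if decision is InplaceDecision.YES_ERASE_COPY and copy_node is not None:
        erased.append(copy_node)
    return decision is not InplaceDecision.NO


erased_nodes: List[str] = []
did_inplace = apply_decision(
    decide_inplace(views_used_later=False, copy_node="copy_"), "copy_", erased_nodes
)
```

The trade-off raised in the thread is visible here: the check becomes side-effect free, but every caller has to handle three outcomes instead of a boolean.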
return False
shared_view_nodes = storage_to_nodes[get_node_storage(mutated_arg)]
if mutated_arg.op == "placeholder":
if any(view.op == "placeholder" for view in shared_view_nodes):
Why is this needed? It's ... plausibly correct to me, but did you find a case where it was needed?
Yes, there was a test case that was failing that used some index function which emitted an unsqueeze.
What was the exact graph being produced?
The gist is:
arg0 <- graph input
buf1 = unsqueeze arg0
buf2 = index_put buf1
buf3 = squeeze buf2
return buf3
arg0 and buf1 share the same memory location, but we only check whether buf1 is a placeholder (it is not), so we inplace it. However, since arg0 shares that memory location, it would be invalid to inplace unless there was a copy at the epilogue as well.
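The difference between the two checks can be shown with a small stand-alone sketch. The `Node` class and the integer `storage` ids below are hypothetical stand-ins for FX nodes and their underlying storages; only the `any(view.op == "placeholder" ...)` expression mirrors the diff above.

```python
from collections import defaultdict
from dataclasses import dataclass


@dataclass(frozen=True)
class Node:
    """Hypothetical stand-in for a graph node."""
    name: str
    op: str        # "placeholder" for graph inputs, "call_function" otherwise
    storage: int   # nodes that alias the same memory share this id


# First two nodes of the graph above: buf1 is a view of the graph input arg0.
arg0 = Node("arg0", "placeholder", storage=0)
buf1 = Node("buf1", "call_function", storage=0)  # unsqueeze(arg0) aliases arg0

storage_to_nodes = defaultdict(list)
for n in (arg0, buf1):
    storage_to_nodes[n.storage].append(n)


def old_check(mutated_arg: Node) -> bool:
    # Looks only at the mutated node itself.
    return mutated_arg.op == "placeholder"


def new_check(mutated_arg: Node) -> bool:
    # Looks at every node that shares the mutated node's storage.
    shared_view_nodes = storage_to_nodes[mutated_arg.storage]
    return any(view.op == "placeholder" for view in shared_view_nodes)


touches_input_old = old_check(buf1)  # misses the alias with arg0
touches_input_new = new_check(buf1)  # sees that arg0 shares buf1's storage
```

With the old check, the index_put on buf1 is treated as if it touched only intermediate memory; the storage-based check recognizes that it would write into the graph input.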
Right, that makes sense. I think at some point when discussing this with you last week I thought of this case, but between then and now I forgot about it :P
We should add a bit more correctness testing as well - test_perf only tests the properties of the graph, but it would be good to see some correctness tests too.
Pull Request resolved: pytorch#112801 Approved by: https://github.com/Chillee, https://github.com/jansel ghstack dependencies: pytorch#112752, pytorch#113008
…ernels (pytorch#113056) Pull Request resolved: pytorch#113056 Approved by: https://github.com/jansel ghstack dependencies: pytorch#112752, pytorch#113008, pytorch#112801
Stack from ghstack (oldest at bottom):
cc @voznesenskym @penguinwu @EikanWang @jgong5 @Guobing-Chen @XiaobingSuper @zhuhaozhe @blzheng @wenzhe-nrv @jiayisunx @peterbell10 @ipiszy @yf225 @chenyang78 @kadeng @muchulee8 @aakhundov @ColinPeppler