[inductor] Implement clone removal for user defined triton kernel via reinplace_scatters #111627
… reinplace_scatters [ghstack-poisoned]
Dr. CI: ✅ No failures as of commit e5236b0 with merge base bf01a7b. See artifacts and rendered test results at hud.pytorch.org/pr/111627
    dst.target == operator.getitem
    and dst.args[0].kwargs["kwargs"][dst.args[1]] == node.args[0]
):
    dst = dst.args[0]
I'm a little iffy about this, since dst becomes triton_kernel_wrapper_functional_proxy here, but with the check above that should be safe.
This seems a bit confusing to me, since getitem is the actual node that we're copying into the input, not triton_kernel_wrapper_functional_proxy. But I guess it's not wrong 🤔
if node.target == aten.copy_.default:
    copy_args_to_copy_nodes[(node.args[0], node.args[1])] = node
    src = node.args[0]
    dst = node.args[1]
Hmm, isn't this flipped? I think the user code dst.copy_(src) will show up in the graph as:
    torch.ops.aten.copy_.default(dst, src)
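To make the ordering concern concrete, here is a minimal sketch of how the unpacking would read if the destination comes first, as in `copy_.default(dst, src)`. `FakeNode` and `unpack_copy` are hypothetical stand-ins written for illustration; they are not `torch.fx.Node` or anything from the PR.

```python
from dataclasses import dataclass
from typing import Any, Tuple


@dataclass(eq=False)
class FakeNode:
    """Hypothetical stand-in for an FX node: a target name plus positional args."""
    target: str
    args: Tuple[Any, ...] = ()


def unpack_copy(node: FakeNode):
    """Return (dst, src) for a copy_ node, assuming destination-first argument order."""
    dst, src = node.args[0], node.args[1]
    return dst, src


# dst.copy_(src) modeled as copy_.default(dst, src)
x = FakeNode("placeholder")          # the graph input being written to
result = FakeNode("getitem")         # the value being copied into it
copy_node = FakeNode("aten.copy_.default", (x, result))

dst, src = unpack_copy(copy_node)
```

With this order, `dst` is the graph input and `src` is the kernel result, which is the opposite of the assignment in the diff above.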
    return False

# Check for any uses other than current node and copy_ epilogue
if len(mutated_arg.users) > 2:
I think this check will miss some cases. In particular, mutated_arg can have more than 2 users, but it can still be ok to reinplace. Example:
    def f(x, out):
        # x will now have **three** users: `add`, `triton_kernel`, and `copy_`
        tmp = torch.add(x, 1)
        triton_kernel[grid](inp=x, out=out)
        return out, tmp
Hmm, I think what we probably want is: if we look at users of mutated_arg later in the graph than the triton kernel, copy_() should be the only user (if this is the case, then it's safe to reinplace).
I'm not sure if there's a builtin way in FX to check "number of users after a given node", but there's a util here that does something similar: https://github.com/pytorch/pytorch/blob/main/torch/fx/passes/reinplace.py#L176
cc @Chillee (lmk if this sounds right to you)
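A rough sketch of the "users after a given node" idea, using a plain list as the graph order. `FakeNode`, `users_after`, and `only_copy_after` are hypothetical names made up for this illustration and are not the FX API or the linked util; a real pass would walk `torch.fx` nodes instead.

```python
from dataclasses import dataclass, field
from typing import List


@dataclass(eq=False)  # eq=False keeps identity comparison between nodes
class FakeNode:
    """Hypothetical stand-in for an FX node (not torch.fx.Node)."""
    name: str
    target: str
    inputs: List["FakeNode"] = field(default_factory=list)


def users_after(order, arg, node):
    """Nodes that consume `arg` and come later than `node` in graph order."""
    start = order.index(node)
    return [n for n in order[start + 1:] if any(i is arg for i in n.inputs)]


def only_copy_after(order, arg, kernel):
    """True if the copy_ epilogue is the sole user of `arg` after `kernel`."""
    later = users_after(order, arg, kernel)
    return len(later) == 1 and later[0].target == "copy_"


x = FakeNode("x", "placeholder")
out = FakeNode("out", "placeholder")
add = FakeNode("tmp", "add", [x])
kernel = FakeNode("kernel", "triton_kernel", [x, out])
copy_node = FakeNode("copy", "copy_", [x, kernel])

# Same five nodes, two orderings: add before the kernel vs. add after it.
early_add = [x, out, add, kernel, copy_node]
late_add = [x, out, kernel, add, copy_node]

ok_early = only_copy_after(early_add, x, kernel)  # add ran before the kernel
ok_late = only_copy_after(late_add, x, kernel)    # add still reads x afterwards
```

In both orderings `x` has three users in total, so a `len(users) > 2` check rejects both, while the position-aware check only rejects the ordering where `add` reads `x` after the kernel.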
Yeah, @Chillee and I talked about this. As you mentioned, this could be improved by doing exactly what you suggested. I could do that as a follow-up.
I think the concrete check is
"are there any views of this node that are used after the mutation"
But yeah, like Oguz said, mainly didn't do it out of laziness/simplicity.
):
    return False

if len(shared_view_nodes) > 2:  # Arg aliases another node other than copy_
This check is probably also too conservative. If there happen to be any views of the input in the graph, then this will also be > 2 and we won't reinplace.
Handling the aliasing case is a bit tricky though (but definitely feels solvable). Do we want to try to handle reinplacing in 100% of cases as part of this PR? Followup also feels totally reasonable, since this seems more like it's just the existing state of the reinplacing pass.
In particular: we'd ideally like to remove the copy_() on x for a case like this:
    def f(x, out):
        x_view = x.view(-1)
        triton_kernel[grid](inp=x, out=out)
        out2 = x_view.mul(2)
        return out, out2
But we want to make sure not to remove the copy_() for a case like this (where we actually do mutate the alias of x later):
    def f(x, out):
        x_view = x.view(-1)
        triton_kernel[grid](inp=x, out=out)
        x_view.mul_(2)
        return out
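The distinction between these two cases can be sketched as "is any alias of x written to after the kernel?". This is only an illustration of the criterion described in this comment, not what reinplace_scatters does today. `Op` and `alias_mutated_after` are hypothetical, and tracking aliases as a set of names is an assumption made to keep the example small.

```python
from typing import FrozenSet, List, NamedTuple


class Op(NamedTuple):
    """Hypothetical record of one graph op: which tensor names it reads and writes."""
    name: str
    reads: FrozenSet[str]
    writes: FrozenSet[str]


def alias_mutated_after(ops: List[Op], kernel_idx: int, aliases: FrozenSet[str]) -> bool:
    """True if any op after the kernel writes to a tensor in the alias set."""
    return any(op.writes & aliases for op in ops[kernel_idx + 1:])


# x and x_view share storage, so both names go into the alias set (assumed).
aliases = frozenset({"x", "x_view"})
view = Op("view", frozenset({"x"}), frozenset())
# Assumption: the kernel is treated as possibly writing to every tensor argument.
kernel = Op("triton_kernel", frozenset({"x", "out"}), frozenset({"x", "out"}))

# First case: the view is only read afterwards (x_view.mul(2)).
read_only = [view, kernel, Op("mul", frozenset({"x_view"}), frozenset())]
# Second case: the view is mutated afterwards (x_view.mul_(2)).
mutating = [view, kernel, Op("mul_", frozenset({"x_view"}), frozenset({"x_view"}))]

keep_copy_first = alias_mutated_after(read_only, 1, aliases)
keep_copy_second = alias_mutated_after(mutating, 1, aliases)
```

Under this criterion the copy_() would be kept only in the second case, matching the intent described above.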
Yup, this could be a good follow-up too. For now, I kept reinplace_scatters as is.
… kernel via reinplace_scatters" cc voznesenskym penguinwu EikanWang jgong5 Guobing-Chen XiaobingSuper zhuhaozhe blzheng wenzhe-nrv jiayisunx peterbell10 ipiszy yf225 chenyang78 kadeng muchulee8 aakhundov ColinPeppler [ghstack-poisoned]
@pytorchbot merge
Merge started: Your change will be merged once all checks pass (ETA 0-4 Hours).
… reinplace_scatters (pytorch#111627) Pull Request resolved: pytorch#111627 Approved by: https://github.com/jansel ghstack dependencies: pytorch#111434
Stack from ghstack (oldest at bottom):
cc @voznesenskym @penguinwu @EikanWang @jgong5 @Guobing-Chen @XiaobingSuper @zhuhaozhe @blzheng @wenzhe-nrv @jiayisunx @peterbell10 @ipiszy @yf225 @chenyang78 @kadeng @muchulee8 @aakhundov @ColinPeppler