KEMBAR78
[inductor] Implement clone removal for user defined triton kernel via reinplace_scatters by oulgen · Pull Request #111627 · pytorch/pytorch · GitHub
Skip to content

Conversation

@pytorch-bot
Copy link

pytorch-bot bot commented Oct 20, 2023

🔗 Helpful Links

🧪 See artifacts and rendered test results at hud.pytorch.org/pr/111627

Note: Links to docs will display an error until the docs builds have been completed.

✅ No Failures

As of commit e5236b0 with merge base bf01a7b (image):
💚 Looks good so far! There are no failures yet. 💚

This comment was automatically generated by Dr. CI and updates every 15 minutes.

oulgen added a commit that referenced this pull request Oct 20, 2023
… reinplace_scatters

ghstack-source-id: 1d009e4
Pull Request resolved: #111627
dst.target == operator.getitem
and dst.args[0].kwargs["kwargs"][dst.args[1]] == node.args[0]
):
dst = dst.args[0]
Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I'm a little iffy about this since dst becomes triton_kernel_wrapper_functional_proxy here but with above check that should be safe

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

seems a bit confusing to me, since getitem is the actual node that we're copying into the input, not triton_kernel_wrapper_functional_proxy. But I guess it's not wrong 🤔

@oulgen oulgen added the ciflow/trunk Trigger trunk jobs on your pull request label Oct 20, 2023
if node.target == aten.copy_.default:
copy_args_to_copy_nodes[(node.args[0], node.args[1])] = node
src = node.args[0]
dst = node.args[1]
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

hmm isn't this flipped? I think the user code dst.copy_(src) will show up in the graph as:

torch.ops.aten.copy_.default(dst, src)

return False

# Check for any uses other than current node and copy_ epilogue
if len(mutated_arg.users) > 2:
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I think this check will miss some cases. In particular, mutated_arg can have more than 2 users, but it can still be ok to reinplace. Example:

def f(x, out):
    # x will now have **three** users: `add`, `triton_kernel`, and `copy_`
    tmp = torch.add(x, 1)
    triton_kernel[grid](inp=x, out=out)
    return out, tmp

Hmm, I think what we probably want is: if we look at users of mutated_arg later in the graph than the triton kernel, copy_() should be the only user (if this is the case, then it's safe to reinplace).

I'm not sure if there's a builtin way in FX to check "number of users after a given node", but there's a util here that does something similar: https://github.com/pytorch/pytorch/blob/main/torch/fx/passes/reinplace.py#L176

cc @Chillee (lmk if this sounds right to you)

Copy link
Contributor Author

@oulgen oulgen Oct 20, 2023

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Yeah, @Chillee and I talked about this. As you mentioned, this could be improved by doing exactly what you suggested. I could do that as follow up.

Copy link
Collaborator

@Chillee Chillee Oct 20, 2023

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I think the concrete check is

"are there any views of this node that are used after the mutation"

But yeah, like Oguz said, mainly didn't do it out of laziness/simplicity.

):
return False

if len(shared_view_nodes) > 2: # Arg aliases another node other than copy_
Copy link
Contributor

@bdhirsh bdhirsh Oct 20, 2023

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This check is probably also too conservative. if there happen to be any views of the input in the graph, then this will also be > 2 and we won't reinplace.

Handling the aliasing case is a bit tricky though (but definitely feels solvable). Do we want to try to handle reinplacing in 100% of cases as part of this PR? Followup also feels totally reasonable, since this seems more like it's just the existing state of the reinplacing pass.

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

In particular: we'd ideally like to remove the copy_() on x for a case like this:

def f(x, out):
    x_view = x.view(-1)
    triton_kernel[grid](inp=x, out=out)
    out2 = x_view.mul(2)
    return out, out2

But we want to make sure not to remove the copy_() for a case like this (where we actually do mutate the alias of x later):

def f(x, out):
    x_view = x.view(-1)
    triton_kernel[grid](inp=x, out=out)
    x_view.mul_(2)
    return out

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

yup, this could be a good follow up too. For now, i kept the reinplace_scatters as is

… kernel via reinplace_scatters"

cc voznesenskym penguinwu EikanWang jgong5 Guobing-Chen XiaobingSuper zhuhaozhe blzheng wenzhe-nrv jiayisunx peterbell10 ipiszy yf225 chenyang78 kadeng muchulee8 aakhundov ColinPeppler

[ghstack-poisoned]
… kernel via reinplace_scatters"

cc voznesenskym penguinwu EikanWang jgong5 Guobing-Chen XiaobingSuper zhuhaozhe blzheng wenzhe-nrv jiayisunx peterbell10 ipiszy yf225 chenyang78 kadeng muchulee8 aakhundov ColinPeppler

[ghstack-poisoned]
oulgen added a commit that referenced this pull request Oct 20, 2023
… reinplace_scatters

ghstack-source-id: d236b82
Pull Request resolved: #111627
… kernel via reinplace_scatters"

cc voznesenskym penguinwu EikanWang jgong5 Guobing-Chen XiaobingSuper zhuhaozhe blzheng wenzhe-nrv jiayisunx peterbell10 ipiszy yf225 chenyang78 kadeng muchulee8 aakhundov ColinPeppler

[ghstack-poisoned]
… kernel via reinplace_scatters"

cc voznesenskym penguinwu EikanWang jgong5 Guobing-Chen XiaobingSuper zhuhaozhe blzheng wenzhe-nrv jiayisunx peterbell10 ipiszy yf225 chenyang78 kadeng muchulee8 aakhundov ColinPeppler

[ghstack-poisoned]
oulgen added a commit that referenced this pull request Oct 22, 2023
… reinplace_scatters

ghstack-source-id: 80c072d
Pull Request resolved: #111627
@oulgen oulgen added the topic: not user facing topic category label Oct 22, 2023
@oulgen
Copy link
Contributor Author

oulgen commented Oct 22, 2023

@pytorchbot merge

@pytorchmergebot
Copy link
Collaborator

Merge started

Your change will be merged once all checks pass (ETA 0-4 Hours).

Learn more about merging in the wiki.

Questions? Feedback? Please reach out to the PyTorch DevX Team

Advanced Debugging
Check the merge workflow status
here

@facebook-github-bot facebook-github-bot deleted the gh/oulgen/10/head branch October 26, 2023 14:24
xuhancn pushed a commit to xuhancn/pytorch that referenced this pull request Nov 7, 2023
Skylion007 pushed a commit to Skylion007/pytorch that referenced this pull request Nov 14, 2023
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment

Projects

None yet

Development

Successfully merging this pull request may close these issues.

5 participants