Optimize mutable torch.library.custom_op overhead by zou3519 · Pull Request #139513 · pytorch/pytorch · GitHub

Conversation

Contributor

@zou3519 zou3519 commented Nov 1, 2024

Stack from ghstack (oldest at bottom):


We don't need to loop over all the args and kwargs in the
ADInplaceOrView key; we just need to bump the version counter on the args and
kwargs that are mutable.

On the benchmark mentioned in
#139494
this change took the times from
```
mutate2 = 61.72943878173828
no_mutate2 = 36.89440155029297
mutate = 236.3092498779297
no_mutate = 59.31964874267578
```
to
```
mutate2 = 47.976478576660156
no_mutate2 = 38.37468719482422
mutate = 71.21315002441406
no_mutate = 59.7432975769043
```

Test Plan:
- existing tests
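
For readers skimming the change, here is a rough sketch of the idea, not the PR's actual diff: the mutable argument positions and keyword names are assumed to be precomputed once from the op's schema, so the per-call hot path only touches tensors that can actually change (the `mutated_idxs` / `mutated_keys` names mirror the hunk discussed further down).

```
import torch
from torch.autograd.graph import increment_version
from torch.utils._pytree import tree_flatten

# Old behavior (sketch): bump the version counter of every tensor argument,
# even read-only inputs, on every call.
def bump_all(args, kwargs):
    flat, _ = tree_flatten((args, kwargs))
    for x in flat:
        if isinstance(x, torch.Tensor):
            increment_version(x)

# New behavior (sketch): mutated_idxs / mutated_keys are precomputed once from
# the op's schema, so only the declared-mutable tensors are touched per call.
def bump_mutated(args, kwargs, mutated_idxs, mutated_keys):
    for idx in mutated_idxs:
        increment_version(args[idx])
    for key in mutated_keys:
        increment_version(kwargs[key])
```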


pytorch-bot bot commented Nov 1, 2024

🔗 Helpful Links

🧪 See artifacts and rendered test results at hud.pytorch.org/pr/139513

Note: Links to docs will display an error until the docs builds have been completed.

✅ No Failures

As of commit 56c2229 with merge base 5e4c8b6:
💚 Looks good so far! There are no failures yet. 💚

This comment was automatically generated by Dr. CI and updates every 15 minutes.

zou3519 added a commit that referenced this pull request Nov 1, 2024
```
for idx in mutated_idxs:
    increment_version(args[idx])
for key in mutated_keys:
    increment_version(kwargs[key])
```
Contributor


If we're worried about Python overhead here, increment_version() should (as of recently) support Iterable[Tensor] as an argument: https://github.com/pytorch/pytorch/blob/main/torch/autograd/graph.py#L226
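
With that overload, the loop in the hunk above could collapse into a single call, roughly like the sketch below (assumes the Iterable[Tensor] support mentioned in the linked graph.py; `mutated_idxs` / `mutated_keys` are the same precomputed indices and keys as in the diff):

```
from torch.autograd.graph import increment_version

def bump_mutated(args, kwargs, mutated_idxs, mutated_keys):
    # One call over an iterable instead of N separate Python-level calls
    # (relies on increment_version accepting Iterable[Tensor]).
    tensors = [args[i] for i in mutated_idxs] + [kwargs[k] for k in mutated_keys]
    increment_version(tensors)
```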

Contributor Author


nice, thanks for pointing that out

Contributor Author


It's not clear to me if building an iterator to pass to increment_version is less expensive than calling increment_version in a loop, but I'll try it out
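
A throwaway microbenchmark along these lines could answer that; the sketch below is hypothetical, assumes the Iterable[Tensor] overload from the comment above, and its numbers will vary by build and machine:

```
import timeit
import torch
from torch.autograd.graph import increment_version

tensors = [torch.randn(2) for _ in range(4)]

def loop_calls():
    # one Python call per tensor
    for t in tensors:
        increment_version(t)

def single_call():
    # one call over the whole iterable
    increment_version(tensors)

print("loop  :", timeit.timeit(loop_calls, number=100_000))
print("single:", timeit.timeit(single_call, number=100_000))
```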

Contributor

@bdhirsh bdhirsh left a comment


Just a side note, but for compile we probably shouldn't be running the ADInplaceOrView kernel at runtime (if we are, then we should make sure that key is disabled when inductor runs), since AOTAutograd handles bumping version counters in its epilogue.
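
For illustration only, excluding that dispatch key for a region of code might look roughly like the sketch below; this uses private torch._C bindings, is not part of this PR, and `run_compiled_region` is a hypothetical stand-in for the inductor-generated call:

```
import torch

def call_without_adinplaceorview(run_compiled_region, *args, **kwargs):
    # Skip ADInplaceOrView for the duration of the call, on the assumption
    # that AOTAutograd's epilogue already bumps the version counters.
    keyset = torch._C.DispatchKeySet(torch._C.DispatchKey.ADInplaceOrView)
    with torch._C._ExcludeDispatchKeyGuard(keyset):
        return run_compiled_region(*args, **kwargs)
```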

Contributor Author

zou3519 commented Nov 4, 2024

Just a side note, but for compile we probably shouldn't be running the ADInplaceOrView kernel at runtime (if we are, then we should make sure that key is disabled when inductor runs), since AOTAutograd handles bumping version counters in its epilogue.

Makes sense, let me file another issue

@zou3519 zou3519 added the ciflow/trunk and release notes: composability labels Nov 5, 2024
Contributor Author

zou3519 commented Nov 5, 2024

@pytorchbot merge

@pytorchmergebot
Collaborator

Merge started

Your change will be merged once all checks pass (ETA 0-4 Hours).

Learn more about merging in the wiki.

Questions? Feedback? Please reach out to the PyTorch DevX Team

Advanced Debugging: check the merge workflow status here.

pobin6 pushed a commit to pobin6/pytorch that referenced this pull request Dec 5, 2024
Pull Request resolved: pytorch#139513
Approved by: https://github.com/bdhirsh
ghstack dependencies: pytorch#139509
@github-actions github-actions bot deleted the gh/zou3519/1086/head branch December 6, 2024 02:12