perf(inductor): use for loop with shortcut in Optimizers to speedup against list comprehensions (e.g. complex conversion)
#110613
Conversation
🔗 Helpful Links: 🧪 See artifacts and rendered test results at hud.pytorch.org/pr/110613
Note: Links to docs will display an error until the docs builds have been completed. ✅ You can merge normally! (2 unrelated failures) As of commit b123833 with merge base cf1b494. UNSTABLE: the following jobs failed but were likely due to flakiness present on trunk and have been marked as unstable.
This comment was automatically generated by Dr. CI and updates every 15 minutes.
Actually, I know how to speed this up even further: we can add a `has_complex` shortcut. This is currently being canary tested in #110607.
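A minimal sketch of that idea, assuming the shape of the change in #110607 (the helper name and signature below are illustrative, not the actual PR code): a group-level `has_complex` flag lets the complex handling be skipped wholesale when no parameter is complex.

```python
import torch

def _maybe_view_complex_as_real(tensorlists, has_complex: bool) -> None:
    # Hypothetical helper illustrating the shortcut: when has_complex is False
    # (the common case), the per-tensor checks and view_as_real calls are
    # skipped entirely instead of rebuilding every list.
    if not has_complex:
        return
    for tensors in tensorlists:
        for i in range(len(tensors)):
            if torch.is_complex(tensors[i]):
                tensors[i] = torch.view_as_real(tensors[i])
```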
Approval contingent on green CI :)
```diff
 ]
-device_params = [torch.view_as_real(x) if torch.is_complex(x) else x for x in device_params]
+for i in range(len(device_params)):
+    if torch.is_complex(device_params[i]):
```
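For readers of the truncated hunk above, a sketch of the complete pattern (reconstructed under the PR's stated intent, not the exact PR code): the old list comprehension rebuilds the whole list on every step, while the in-place loop only touches the entries that are actually complex.

```python
import torch

def _convert_complex_inplace(device_params: list) -> None:
    # Sketch: mutate the existing list instead of building a new one, and only
    # pay for torch.view_as_real on entries that are actually complex.
    for i in range(len(device_params)):
        if torch.is_complex(device_params[i]):
            device_params[i] = torch.view_as_real(device_params[i])
```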
nice
```diff
 for ((device_params, device_grads, device_state_sums, device_state_steps), _) in grouped_tensorlists.values():
-    device_has_sparse_grad = any(grad.is_sparse for grad in device_grads)
+    device_has_sparse_grad = has_sparse_grad and any(grad.is_sparse for grad in device_grads)
```
oh, very good heuristic catch
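A minimal illustration of the catch (variable names follow the diff above; this is not the PR code itself): Python's `and` short-circuits, so the per-gradient `is_sparse` scan only runs when the group-level `has_sparse_grad` flag is already True.

```python
def _device_has_sparse_grad(has_sparse_grad: bool, device_grads: list) -> bool:
    # If the group-level flag is False (the common, dense-only case), the
    # generator inside any(...) is never evaluated, so we avoid iterating over
    # every gradient just to rediscover that none of them are sparse.
    return has_sparse_grad and any(grad.is_sparse for grad in device_grads)
```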
Anyway, the `has_complex` changes are orthogonal, so I'm also down to just merge this one first and rebase #110607 and further `has_complex` improvements on top of this PR.
Also, please land the Adagrad sparse fix separately.
Hey! Don't land before addressing my new comment
@janeyx99 to be more explicit, this is how the single_tensor case always shortcuts (code permalinks at commit 0296632):
- Line 375
- Line 262
- Line 397
- Line 251
- pytorch/torch/optim/rmsprop.py, Line 275
- Line 238
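For readers without the permalinks handy, a rough sketch of what that single-tensor shortcut looks like (illustrative only; the real code is at the lines linked above, e.g. in pytorch/torch/optim/rmsprop.py): each parameter is checked individually, so non-complex parameters never pay for `view_as_real`.

```python
import torch

def _single_tensor_step_sketch(params, grads):
    # Illustrative only: real single-tensor optimizer loops check each param's
    # dtype individually, so the complex handling reduces to one cheap
    # torch.is_complex test that is skipped for the (common) non-complex case.
    for param, grad in zip(params, grads):
        if torch.is_complex(param):
            param = torch.view_as_real(param)
            grad = torch.view_as_real(grad)
        # ...the actual update math would follow here...
```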
@pytorchbot merge
Merge started. Your change will be merged once all checks pass (ETA 0-4 hours). Learn more about merging in the wiki. Questions? Feedback? Please reach out to the PyTorch DevX Team.
…tcut (#110635) Partial fix: #110606 More on `has_complex` shortcut: #110613 (comment) CC: @janeyx99 @mlazos @lezcano Pull Request resolved: #110635 Approved by: https://github.com/lezcano
…cut (#110634) Partial fix: #110606 More on `has_complex` shortcut: #110613 (comment) CC: @janeyx99 @mlazos @lezcano Pull Request resolved: #110634 Approved by: https://github.com/lezcano
…_complex` shortcut (#110631) Partial fix: #110606 More on `has_complex` shortcut: #110613 (comment) CC: @janeyx99, @mlazos, @lezcano Pull Request resolved: #110631 Approved by: https://github.com/lezcano
Fully fixes: #110506
Depends: #110607
Potential merge conflicts:
- `Adagrad` will use `device` when `capturable=True` always when compiling with dynamo (#110339)
- `Adamax` to be better fused by Inductor and enable it (#110345)

Related:
- `NAdam`, `RAdam` and `_multi_tensor_adadelta` do not support complex types (#110606) (we can apply the improvements here orthogonally to the complex support)

Results
Benchmark: 100 params.
Breakdowns (float32, dynamo):
Notes:
- `_get_value` list comprehension. Can be fixed in https://github.com/pytorch/pytorch/pull/110339/files by utilizing the capturable path.
- `call_user_compiler` from `compile_inner` timing.

This PR:
Main
Seems that it doesn't help the complex case by much (but that's not the majority case). torch.float32 is generally positive; when it does not show a drastic improvement or regresses, that is due to Inductor variance (verified by manually inspecting the logs).
Benchmark Script
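The collapsed benchmark script from the PR description is not reproduced here; a minimal sketch of the kind of measurement described above (100 params, a compiled optimizer step, wall-clock timing) might look like the following. The optimizer choice, tensor sizes, and step count are assumptions, not the original setup.

```python
import time
import torch

def benchmark_compiled_step(num_params: int = 100, dtype=torch.float32, steps: int = 10) -> float:
    # Build num_params small parameters with pre-populated gradients.
    params = [torch.randn(64, 64, dtype=dtype, requires_grad=True) for _ in range(num_params)]
    for p in params:
        p.grad = torch.randn_like(p)

    # Assumed optimizer; the PR benchmarks several torch.optim optimizers.
    opt = torch.optim.Adagrad(params)
    compiled_step = torch.compile(opt.step)
    compiled_step()  # warm-up: triggers dynamo/inductor compilation

    start = time.perf_counter()
    for _ in range(steps):
        compiled_step()
    return (time.perf_counter() - start) / steps

if __name__ == "__main__":
    print(f"avg compiled step time: {benchmark_compiled_step():.6f} s")
```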
CC: @janeyx99 @mlazos