feat(inductor): Add RAdam to Inductor by converting data-dependent control-flow to torch.where
#110351
Conversation
🔗 Helpful Links
🧪 See artifacts and rendered test results at hud.pytorch.org/pr/110351

Note: Links to docs will display an error until the docs builds have been completed.

⏳ No Failures, 1 Pending, as of commit cf655e4 with merge base 4069d1d.

This comment was automatically generated by Dr. CI and updates every 15 minutes.
I hope at some point we can have polymorphic torch.sqrt(1.0) -> 1.0 for python scalars, just like in NumPy, and a torch.pi, so we don't have to import math just for these simple things and can have dispatch helpers :) Or speed up CPU scalar tensors (https://github.com/jon-chuang/pytorch/blob/89eb7a75a251c41c4bee86e9ede1001b0d3998af/torch/optim/optimizer.py#L84).
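The kind of dispatch helper being asked for could look like the sketch below. The name `dispatch_sqrt` and its shape are assumptions for illustration, modeled on the `_dispatch_sqrt` helper referenced elsewhere in this thread, not PyTorch's actual implementation.

```python
import math

import torch


def dispatch_sqrt(x):
    # Hypothetical helper: use tensor sqrt for tensors and math.sqrt for
    # python scalars, since torch.sqrt does not accept a plain python float.
    if isinstance(x, torch.Tensor):
        return torch.sqrt(x)
    return math.sqrt(x)
```

With such a helper, call sites no longer need to branch on the input type themselves, e.g. `dispatch_sqrt(4.0)` and `dispatch_sqrt(torch.tensor(4.0))` both work.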
Force-pushed from 88c348b to a0f5b5b (…uang/radam-inductor)
```diff
-    if rho_t > 5
-    else 0
-    for rho_t in rho_t_list
+    torch.where(
```
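The change above can be sketched as follows. The rectification expression `f` and the tensor shapes are placeholders for illustration, not RAdam's actual formula.

```python
import torch


def rectify(rho_t: torch.Tensor) -> torch.Tensor:
    # Data-dependent python control flow such as
    #     rect = f(rho_t) if rho_t > 5 else 0
    # cannot be traced by Inductor when rho_t is a tensor. torch.where keeps
    # the branch as a tensor op. f here is a placeholder expression.
    f = lambda r: (r - 4.0) / r
    return torch.where(rho_t > 5, f(rho_t), torch.zeros_like(rho_t))


rho_t = torch.tensor([4.0, 6.0])
rect = rectify(rho_t)  # first element takes the else-branch (0)
```

Note that torch.where evaluates both branches eagerly, which is the trade-off discussed in the comments below.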
Hmm, not sure this is a good idea in the eager case? Anyway, it executes on CPU, but still might be a bit slow
Putting this on the capturable path would help resolve this as well
```diff
 unrect_step_size = [(lr * _get_value(rect) / bc) * -1 for rect, bc in zip(unrectified, bias_correction1)]
 bias_correction2_sqrt_times_rect_step_size = [
-    _dispatch_sqrt(1 - beta2 ** _get_value(step)) * (lr * rect / bc) * -1
+    _dispatch_sqrt(1 - beta2 ** _get_value(step)) * (lr * _get_value(rect) / bc) * -1
```
We should probably have a "capturable" code path similar to Adam/NAdam and now Adagrad (#110339) / Adamax
It can use foreach_sqrt for the GPU path? 🤔
The only reason we wouldn't use it here is because we want this bias correction to be on the Scalar codepath, but that doesn't really make that much sense since it is consumed in foreach below.
Hmm, I think the issue with converting to foreach here is the _get_value part, but honestly I could use a list comprehension to create the corresponding step values.
Adding a capturable code path may be the right move here. Then, we could get rid of the need for get_value as the capturable path should use the foreach ops, and the eager path can still keep scalar math on python.
```python
if not torch.jit.is_scripting() and isinstance(x, torch.Tensor):
    return x.abs()
else:
    return abs(x)
```
I improved dynamo's abs support in #110398 so this function shouldn't be needed any more.
Thanks @peterbell10! @jon-chuang let's rebase this PR onto that one and eradicate this?
Agreed that a capturable path makes more sense here. However, adding capturable support properly requires more work, as you should also support it in the single-tensor case and add the right test cases in test_cuda and test_optim.
For this specific case, can the torch.sqrt call be replaced by …
Looks like this PR hasn't been updated in a while, so we're going to go ahead and mark this as …
For small epochs (adaptive learning rate = 1), this will be more costly than not computing `rect` and `adaptive_lr`, but unlikely by much: they are all fused into a single kernel anyway.

cc @voznesenskym @penguinwu @EikanWang @jgong5 @Guobing-Chen @XiaobingSuper @zhuhaozhe @blzheng @Xia-Weiwen @wenzhe-nrv @jiayisunx @peterbell10 @ipiszy @yf225 @chenyang78 @kadeng @muchulee8 @aakhundov @ColinPeppler