feat(inductor): Add `RAdam` to Inductor by converting data-dependent control-flow to `torch.where` by jon-chuang · Pull Request #110351 · pytorch/pytorch · GitHub

Conversation

@jon-chuang
Collaborator

@jon-chuang jon-chuang commented Oct 1, 2023

For the first few steps (where the rectification is inactive and the adaptive learning rate is effectively 1), this will be more costly than skipping the rect and adaptive_lr computation entirely, but likely not by much - they are all fused into a single kernel anyway.
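A minimal sketch of the conversion being described, using the rectification term from the RAdam paper; the names (rho_t, rho_inf, rect_eager, rect_where) are illustrative assumptions, not the PR's actual code:

import math
import torch

# Eager-style rectification: a data-dependent Python branch that Inductor
# cannot trace through without a graph break.
def rect_eager(rho_t: float, rho_inf: float) -> float:
    if rho_t > 5:
        return math.sqrt(
            (rho_t - 4) * (rho_t - 2) * rho_inf
            / ((rho_inf - 4) * (rho_inf - 2) * rho_t)
        )
    return 0.0

# Branch-free form: both outcomes are computed on tensors and selected with
# torch.where, so the whole update can fuse into one kernel. The unselected
# lane may produce NaN for small rho_t, but torch.where discards it.
def rect_where(rho_t: torch.Tensor, rho_inf: torch.Tensor) -> torch.Tensor:
    rect = torch.sqrt(
        (rho_t - 4) * (rho_t - 2) * rho_inf
        / ((rho_inf - 4) * (rho_inf - 2) * rho_t)
    )
    return torch.where(rho_t > 5.0, rect, torch.zeros_like(rect))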

cc @voznesenskym @penguinwu @EikanWang @jgong5 @Guobing-Chen @XiaobingSuper @zhuhaozhe @blzheng @Xia-Weiwen @wenzhe-nrv @jiayisunx @peterbell10 @ipiszy @yf225 @chenyang78 @kadeng @muchulee8 @aakhundov @ColinPeppler

@pytorch-bot

pytorch-bot bot commented Oct 1, 2023

🔗 Helpful Links

🧪 See artifacts and rendered test results at hud.pytorch.org/pr/110351

Note: Links to docs will display an error until the docs builds have been completed.

⏳ No Failures, 1 Pending

As of commit cf655e4 with merge base 4069d1d:
💚 Looks good so far! There are no failures yet. 💚

This comment was automatically generated by Dr. CI and updates every 15 minutes.

@vadimkantorov
Contributor

vadimkantorov commented Oct 1, 2023

I hope at some point we can have a polymorphic torch.sqrt(1.0) -> 1.0 for Python scalars, just like in NumPy (and likewise torch.pi), so we don't have to import math just for these simple things or keep dispatch helpers around :) Or, alternatively, speed up the CPU scalar tensors in https://github.com/jon-chuang/pytorch/blob/89eb7a75a251c41c4bee86e9ede1001b0d3998af/torch/optim/optimizer.py#L84
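For reference, the dispatch helpers being referred to look roughly like this (paraphrased from torch/optim/optimizer.py; the exact code may differ):

import math
import torch

def _dispatch_sqrt(x):
    # Tensor input: use the tensor op; Python scalar: fall back to math.sqrt.
    if not torch.jit.is_scripting() and isinstance(x, torch.Tensor):
        return x.sqrt()
    return math.sqrt(x)

def _get_value(x):
    # Pull a Python number out of a 0-dim (possibly CPU) scalar tensor.
    return x.item() if isinstance(x, torch.Tensor) else x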

@jon-chuang jon-chuang force-pushed the jon-chuang/radam-inductor branch from 88c348b to a0f5b5b on October 1, 2023 14:14
-            if rho_t > 5
-            else 0
-            for rho_t in rho_t_list
+            torch.where(
Collaborator Author

Hmm, not sure this is a good idea in the eager case? Anyway, it executes on the CPU, but it still might be a bit slow.

Contributor

Putting this on the capturable path would help resolve this as well

  unrect_step_size = [(lr * _get_value(rect) / bc) * -1 for rect, bc in zip(unrectified, bias_correction1)]
  bias_correction2_sqrt_times_rect_step_size = [
-     _dispatch_sqrt(1 - beta2 ** _get_value(step)) * (lr * rect / bc) * -1
+     _dispatch_sqrt(1 - beta2 ** _get_value(step)) * (lr * _get_value(rect) / bc) * -1
Collaborator Author

We should probably have a "capturable" code path similar to Adam/NAdam and now Adagrad (#110339) / Adamax

It can use foreach_sqrt for the GPU path? 🤔

The only reason we wouldn't use it here is that we want this bias correction to stay on the scalar codepath, but that doesn't really make much sense, since it is consumed by the foreach ops below.

Collaborator Author

Hmm, I think the issue with converting to foreach here is the _get_value part, but honestly I could use a list comprehension to create the corresponding step values.

Contributor

Adding a capturable code path may be the right move here. Then we could get rid of the need for _get_value, as the capturable path should use the foreach ops, and the eager path can still keep the scalar math in Python.
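As a rough illustration (not this PR's code), a capturable path could mirror what Adam's multi-tensor capturable branch does and keep the bias correction entirely on tensors via foreach ops, so _get_value is never needed; the helper name below is hypothetical:

import torch

def _capturable_bias_correction2_sqrt(state_steps, beta2):
    # state_steps: list of 0-dim step tensors on the same device as the params.
    bc2 = torch._foreach_pow(beta2, state_steps)  # beta2 ** step
    torch._foreach_sub_(bc2, 1)                   # beta2 ** step - 1
    torch._foreach_neg_(bc2)                      # 1 - beta2 ** step
    return torch._foreach_sqrt(bc2)               # sqrt(1 - beta2 ** step)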

@colesbury colesbury added the triaged (This issue has been looked at by a team member, and triaged and prioritized into an appropriate module) label Oct 3, 2023
if not torch.jit.is_scripting() and isinstance(x, torch.Tensor):
return x.abs()
else:
return abs(x)
Collaborator

I improved dynamo's abs support in #110398 so this function shouldn't be needed any more.
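A quick sanity check of that behavior, using the plain abs builtin (which works on both Python scalars and tensors):

import torch

print(abs(-3))                    # 3 (Python int)
print(abs(torch.tensor([-3.0])))  # tensor([3.]) via Tensor.__abs__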

Contributor

Thanks @peterbell10! @jon-chuang let's rebase this PR onto that one and eradicate this?

Contributor

@janeyx99 janeyx99 left a comment

Agreed that a capturable path will make more sense here. However, adding capturable well requires more work, as you should also support it in the single-tensor case and add the right test cases in test_cuda and test_optim.

@vadimkantorov
Contributor

> I hope at some point, we can have polymorphic torch.sqrt

For this specific case, can the torch.sqrt call be replaced by x ** 0.5? That expression works for both tensors and Python scalars (does the torch.pow call behind x ** 0.5 dispatch directly to torch.sqrt when the exponent is 0.5?).
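A quick check of that suggestion, outside of any optimizer code:

import torch

x_scalar = 2.0
x_tensor = torch.tensor([2.0, 4.0])

print(x_scalar ** 0.5)  # 1.4142... (stays a Python float)
print(x_tensor ** 0.5)  # tensor([1.4142, 2.0000]) via Tensor.__pow__ / torch.pow
# torch.sqrt(2.0) raises a TypeError, since torch.sqrt expects a Tensor.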

@github-actions
Contributor

github-actions bot commented Dec 4, 2023

Looks like this PR hasn't been updated in a while so we're going to go ahead and mark this as Stale.
Feel free to remove the Stale label if you feel this was a mistake.
If you are unable to remove the Stale label please contact a maintainer in order to do so.
If you want the bot to never mark this PR stale again, add the no-stale label.
Stale pull requests will automatically be closed after 30 days of inactivity.

@github-actions github-actions bot added the Stale label Dec 4, 2023
@github-actions github-actions bot closed this Jan 3, 2024

Labels

ciflow/inductor, module: dynamo, module: inductor, open source, release notes: optim, Stale, triaged

6 participants