Alias for logsumexp to special namespace #58838
Conversation
💊 CI failures summary and remediations (Dr. CI, automatically generated): as of commit 7de4af3 (more details on the Dr. CI page and at hud.pytorch.org/pr/58838), 💚 Looks good so far! There are no failures yet. 💚 Preview docs built from this PR.
Updates: Apologies for the long hold on this PR. Most of the issues have been resolved; waiting for the tests to pass, then it will be ready for review. :)
Overall looks good to me, but I have one suggested change.
Thanks @krshrimali
Left a few small points. Overall looks very good, thanks @krshrimali
torch/special/__init__.py
Outdated
Example::
    >>> a = torch.randn(3, 3)
    >>> torch.logsumexp(a, 1)
This should be torch.special.logsumexp.
It would also be more meaningful to write here:
>>> torch.dist(torch.special.logsumexp(a, 1), torch.log(torch.sum(torch.exp(a), dim)))
tensor(1.1921e-07)
I unfortunately have a different opinion on this change: I think we should be consistent with the examples for other functions, and I feel that having >>> torch.special.logsumexp(a, 1) is OK here.
Okay, I think your suggestion also makes sense. To keep the example simple while still conveying what the function does, how about this instead?
>>> torch.special.logsumexp(a, 1)
#output
>>> torch.dist(torch.special.logsumexp(a, 1), torch.log(torch.sum(torch.exp(a), dim)))
#output
I think having both would be a good addition. Note that the formula I wrote had an error: dim should be 1.
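For reference, a minimal sketch of the combined example being discussed, with the errata applied (dim set to 1 in the reference formula). It assumes the torch.special.logsumexp alias added in this PR; the printed outputs are omitted because they depend on the random input.

>>> import torch
>>> a = torch.randn(3, 3)
>>> torch.special.logsumexp(a, 1)
>>> # The result should match the naive formula up to floating-point error:
>>> torch.dist(torch.special.logsumexp(a, 1), torch.log(torch.sum(torch.exp(a), 1)))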
Gentle ping, @mruberry - Can you please take a look at this whenever you find time? Thanks!
I do not know enough to say whether the changes in the JIT are fine; I'll let @mruberry assess that or find someone who knows about this. Everything else looks good to me. Thank you, Kush!
On this topic, I have a question: would this fix also resolve all (or a number of) the tests that we used to skip in the OpInfos, or does it address something different? @krshrimali @kshitij12345
If so, this is great! We should then remove the relevant skips.
If not, how do other functions with this problem deal with the JIT errors that we found here?
Thanks, @lezcano, for taking a closer look at this PR. I appreciate your comments. I will update the docs to add both formulae.
This is a great question, the fix in this PR is for
Agreed! :)
Update: There are currently 2 tests skipped:
Thanks for asking the right question, @lezcano!
Codecov Report
@@ Coverage Diff @@
## master #58838 +/- ##
=======================================
Coverage 76.23% 76.24%
=======================================
Files 2054 2054
Lines 205033 205117 +84
=======================================
+ Hits 156306 156387 +81
- Misses 48727 48730 +3
torch/_torch_docs.py
Outdated
>>> a = torch.randn(3, 3)
>>> torch.logsumexp(a, 1)
tensor([ 0.8442, 1.4322, 0.8711])
Alias for :func:`torch.special.logsumexp`.
Let's keep the docs in torch for now and have special.logsumexp alias it.
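As a hypothetical illustration of that suggestion (not the actual PyTorch wiring, which registers the alias at the native-op level), the idea is that the detailed docstring and example stay on torch.logsumexp, while the special-namespace entry is a thin alias whose documentation simply points back to it. The helper name below is made up for illustration:

import torch

def special_logsumexp(input, dim, keepdim=False, *, out=None):
    """Alias for :func:`torch.logsumexp`."""
    # Delegate to the torch-namespace function, which keeps the full
    # documentation in torch/_torch_docs.py as suggested above.
    return torch.logsumexp(input, dim, keepdim=keepdim, out=out)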
skips=(
    # Expected a value of type 'int' for argument 'source'
    # but instead found type 'list'.
    SkipInfo('TestJit', 'test_jit_alias_remapping'),
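For context on the skip being removed, a rough sketch (an assumption about what the check amounts to, not the actual OpInfo test code) of the behaviour test_jit_alias_remapping exercises: a scripted call through the special-namespace alias should resolve to, and match, the original op.

import torch

# Sketch only: assumes the torch.special.logsumexp alias from this PR is
# registered and visible to TorchScript.
@torch.jit.script
def scripted_logsumexp(x: torch.Tensor) -> torch.Tensor:
    return torch.special.logsumexp(x, dim=1)

a = torch.randn(3, 3)
assert torch.allclose(scripted_logsumexp(a), torch.logsumexp(a, dim=1))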
Awesome skip removal
Thanks for removing the skipped test 😍 😍 😍 LGTM if tests pass
LGTM! Thanks @krshrimali
@mruberry has imported this pull request. If you are a Facebook employee, you can view this diff on Phabricator.
See #50345
cc: @kshitij12345 @lezcano @mruberry