-
Notifications
You must be signed in to change notification settings - Fork 25.7k
[fix] torch.{lin, log}space(): properly examine passed dtype #53685
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
[fix] torch.{lin, log}space(): properly examine passed dtype #53685
Conversation
💊 CI failures summary and remediationsAs of commit 4bc39f0 (more details on the Dr. CI page):
1 failure not recognized by patterns:
This comment was automatically generated by Dr. CI (expand for details).Follow this link to opt-out of these comments for your Pull Requests.Please report bugs/suggestions to the (internal) Dr. CI Users group. |
if (_r.isNone(4)) { | ||
// aten::logspace(Scalar start, Scalar end, int? steps=None, float base=10.0, *, ScalarType? dtype=None, Layout? layout=None, Device? device=None, bool? pin_memory=None) -> Tensor | ||
const auto options = TensorOptions() | ||
.dtype(_r.scalartypeOptional(5)) |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Everything is same here w.r.t generated code except for _r.scalartypeOptional(5)
instead of _r.scalartype(5)
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Ditto re explanatory comment
if (_r.isNone(3)) { | ||
// aten::linspace(Scalar start, Scalar end, int? steps=None, *, ScalarType? dtype=None, Layout? layout=None, Device? device=None, bool? pin_memory=None) -> Tensor | ||
const auto options = TensorOptions() | ||
.dtype(_r.scalartypeOptional(4)) |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Everything is same here w.r.t generated code except for _r.scalartypeOptional(5) instead of _r.scalartype(5)
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Might want to add a comment a la https://github.com/pytorch/pytorch/blob/master/tools/autograd/templates/python_torch_functions.cpp#L104 - helps explain why this needs to be a manual binding
@anjali411 Would you take the lead on reviewing this? |
@kshitij12345 sorry I missed the notification for this PR before. This PR looks great overall. |
const auto steps_ = steps.value_or(100); | ||
TORCH_CHECK(steps_ >= 0, "number of steps must be non-negative"); | ||
auto result_options = linspace_logspace_infer_options(start, end, options); | ||
auto result_options = linspace_logspace_infer_options(start, end, options, "torch.linspace()"); |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
@kshitij12345 so if the user doesn't specify the dtype, the dtype is not set until here, or the line below in which case we create a tensor with the default dtype?
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Apologies missed the comment. Thanks @bhosmer for pointing out. Have updated linspace_logspace_infer_options
.
@anjali411 has imported this pull request. If you are a Facebook employee, you can view this diff on Phabricator. |
static PyObject * THPVariable_numel(PyObject* self_, PyObject* args, PyObject* kwargs); | ||
|
||
// linspace | ||
static PyObject * THPVariable_linspace(PyObject* self_, PyObject* args, PyObject* kwargs) |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
@bhosmer would you like to check these custom bindings?
Hey @kshitij12345, this looks correct but I'll need more time to review the new handwritten bindings; I cc'd @bhosmer to see if he'd like to take a look or suggest an alternative approach, too With the branch cut I'll need some extra time to review the new bindings carefully; would you ping me on this next week? I hate to make you wait but it's an incredibly busy time |
Linking with #56335, which also updates lin and log space |
@mruberry Sure. I understand. Will ping you by next week here! @mruberry @anjali411 Thanks for looking into it :) |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Hey @mruberry @anjali411 @kshitij12345 - yeah, unfortunately the codegen still locks you into the dtype defaulting behavior for ops that take TensorOptions args, so manual bindings are the way to make it work for now. Couple small suggestions inline, otherwise LGTM
if (_r.isNone(3)) { | ||
// aten::linspace(Scalar start, Scalar end, int? steps=None, *, ScalarType? dtype=None, Layout? layout=None, Device? device=None, bool? pin_memory=None) -> Tensor | ||
const auto options = TensorOptions() | ||
.dtype(_r.scalartypeOptional(4)) |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Might want to add a comment a la https://github.com/pytorch/pytorch/blob/master/tools/autograd/templates/python_torch_functions.cpp#L104 - helps explain why this needs to be a manual binding
if (_r.isNone(4)) { | ||
// aten::logspace(Scalar start, Scalar end, int? steps=None, float base=10.0, *, ScalarType? dtype=None, Layout? layout=None, Device? device=None, bool? pin_memory=None) -> Tensor | ||
const auto options = TensorOptions() | ||
.dtype(_r.scalartypeOptional(5)) |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Ditto re explanatory comment
} | ||
|
||
return result_options; | ||
return options; |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Per @anjali411's comment below, the other change introduced here besides the error checking is that the options we return will no longer have the default dtype patched in, in the non-complex case.
It may be that the current downstream consumer (empty
) happens to implement the same defaulting behavior atm, but I think it'll be more robust if we don't depend on that, and instead do it here.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Updated. Thanks!!
Codecov Report
@@ Coverage Diff @@
## master #53685 +/- ##
==========================================
+ Coverage 76.44% 76.80% +0.36%
==========================================
Files 2022 1986 -36
Lines 202375 198202 -4173
==========================================
- Hits 154701 152237 -2464
+ Misses 47674 45965 -1709 |
|
@mruberry Gentle ping :) |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
LGTM! Thanks @kshitij12345 this is ready for land once you rebase it and update the comment for more context! :D
@anjali411 has imported this pull request. If you are a Facebook employee, you can view this diff on Phabricator. |
…ogspace-passed-dtype
@anjali411 gentle ping :) |
@anjali411 has imported this pull request. If you are a Facebook employee, you can view this diff on Phabricator. |
@anjali411 merged this pull request in c902609. |
Fixes #53171