pow scalar exponent / base autodiff, fusion #19324
Conversation
So one of the problems seems to be from the Scalar<->float magic in symbolic_script, which seems to break when the Scalar is an int. Unfortunately, the only thing I could think of is a gross hack: adding torch._float to make a float from int/float IValues passed into a Scalar.
torch/csrc/jit/symbolic_script.cpp
Outdated
exponent: float):
def backward(grad_output):
grad_self = torch.where(torch.tensor(exponent == 0.0), torch.zeros_like(self), grad_output * exponent * torch.pow(self, exponent - 1))
if torch._float(exponent) == 0.0:
What's preventing float(exponent) instead?
So what I think is going on: When creating the pow_0 operator for Tensor self, Scalar exponent, symbolic_script replaces Scalar by float. The float(exponent) cast gets eliminated because the JIT "knows" it is a float. What then happens is that unpacking Scalar IValues that the JIT thinks must be float but that are, in fact, int fails.
In a way it comes down to
// 2. to make sure the input of any graph node does not contain scalar type
// in its argument, all scalar arg should already be passed with float
// value since scalar/int aren't differentiable either way.
not being the complete picture, because (as here) the Scalar might not be the thing we want to differentiate with respect to, but a parameter (in the mathematical sense, as opposed to the variable) of the function we want to differentiate.
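To make that concrete, here is a minimal sketch under my own assumptions, not the exact repro from the issue, of the kind of call being described: an int exponent reaching a backward graph whose Scalar input was rewritten to float.

```python
# Hedged sketch of the failure mode discussed above (not the exact repro
# from #19253): the exponent literal is an int, so the Scalar IValue that
# reaches the backward graph is an int even though the graph, after the
# Scalar -> float replacement, expects a float.
import torch

@torch.jit.script
def f(x):
    # int literal exponent -> pow(Tensor, Scalar) where the Scalar is an int
    return x.pow(2)

x = torch.randn(3, requires_grad=True)
y = f(x).sum()
y.backward()  # unpacking the "float" exponent is where things could blow up
```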
So I'm changing the patch to do the following:
If the Scalar -> float conversion happened, I change the input type of the graph back to Scalar and insert a conversion (prim::Float) as the first thing.
It'll be troublesome for us once we get operations that actually rely on the difference between float and int in Scalar ops, but as far as I know we currently don't have any.
I think this is a clean-up of the Scalar->float conversion as it ensures that the graph inputs actually match the schema. @ailzhang does that seem reasonable?
So it turns out that re-Scalarizing breaks something (the double backward?). 😓
cc @wanchaol, who added the scalar-to-float conversion.
Also, the scalar-to-float conversion only happens on the second pass, when allow_conversions is true; if there is an op defined for Scalar, that one will be matched and there won't be a conversion.
@t-vi As @eellison pointed out offline, we can actually get rid of the scalar -> float conversion entirely.
For example, this works for me:
def pow_0(self,
          exponent: number):
Note that we don't expose number in TorchScript, but we CAN compile it! :D With this we can easily get rid of the current symbolic_variable.h and the c10::ReplaceAll(schema_str, "Scalar", "float"); call.
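For illustration, here is a minimal sketch of what a full symbolic_script entry with the number annotation could look like, reusing the backward formula from the outdated diff above; the return convention (output plus a backward closure returning one gradient per input) is my assumption about the surrounding definitions, not the merged code.

```python
# Sketch only: `number` annotation as suggested above, backward formula
# taken from the earlier diff. The exact merged definition may differ.
def pow_0(self,
          exponent: number):
    def backward(grad_output):
        grad_self = torch.where(torch.tensor(exponent == 0.0),
                                torch.zeros_like(self),
                                grad_output * exponent * torch.pow(self, exponent - 1))
        # one gradient per input; the scalar exponent gets no gradient
        return grad_self, None
    return torch.pow(self, exponent), backward
```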
Huge thanks to @eellison who pointed this out!
Let us know if this fixes your problem :D
Ha. That works (I think)! And I've been poking around much too long. Awesome.
Thanks so much @ailzhang and @eellison !
So my understanding is that with this, one would only have definitions that match the operator schemas. Could we actually check that? For now I left a fallback in, but converted the replacements that were hit by a test_jit.py run.
torch/csrc/jit/symbolic_script.cpp
Outdated
auto sym_script_it = schema_to_graphs.find(schema_str);
if (sym_script_it == schema_to_graphs.end()) {
Yeah, this if should be dropped before merging.
So I removed it after double checking by means of calling sig(schema_string) and seeing whether it throws here:
pytorch/torch/csrc/jit/symbolic_script.cpp
Lines 1351 to 1353 in 3e0b46b
auto schema_string = overloadedSchemaString(actual_schema);
schema_to_graphs[schema_string] = std::move(pair);
LGTM! Let me know when you are done with changing and want to merge it.
I think it's good to merge.
On 18 April 2019 at 20:27:33 CEST, Ailing <notifications@github.com> wrote:
…ailzhang approved this pull request.
LGTM! Let me know when you are done with changing and want to merge it.
@ailzhang has imported this pull request. If you are a Facebook employee, you can view this diff on Phabricator.
Summary:
Fixes: pytorch#19253
Fixing pow(Tensor, float) is straightforward. The breakage for pow(float, Tensor) is a bit more subtle to trigger, and fixing needs `torch.log` (`math.log` didn't work) from the newly merged pytorch#19115. (Thanks ngimel for pointing out this has landed.)
Pull Request resolved: pytorch#19324
Differential Revision: D15003531
Pulled By: ailzhang
fbshipit-source-id: 8b22138fa27a43806b82886fb3a7b557bbb5a865
Fixes: #19253
Fixing pow(Tensor, float) is straightforward.
The breakage for pow(float, Tensor) is a bit more subtle to trigger, and fixing it needs `torch.log` (`math.log` didn't work) from the newly merged #19115 (thanks @ngimel for pointing out this has landed).
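For reference, a hedged sketch of the two call patterns described above (my reconstruction, not the tests added by this PR):

```python
# Sketch of the two cases above; on a build with this fix both backward
# passes should run without error.
import torch

x = torch.randn(4, requires_grad=True)

# pow(Tensor, float): scalar exponent
torch.pow(x, 2.0).sum().backward()

# pow(float, Tensor): scalar base, tensor exponent; its gradient involves
# log(base), which is why torch.log from #19115 was needed
torch.pow(2.0, x).sum().backward()
```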