pow scalar exponent / base autodiff, fusion by t-vi · Pull Request #19324 · pytorch/pytorch · GitHub

Conversation

@t-vi (Collaborator) commented Apr 16, 2019

Fixes: #19253

Fixing pow(Tensor, float) is straightforward.
The breakage for pow(float, Tensor) is a bit more subtle to trigger, and fixing it needs torch.log (math.log didn't work) from the newly merged #19115. (Thanks @ngimel for pointing out that this has landed.)
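To see why the pow(float, Tensor) backward needs torch.log at all, here is a minimal sketch (not the PR's actual symbolic script): for y = base ** x with a scalar base and tensor exponent x, dy/dx = y * ln(base), so the backward multiplies grad_output by the log of the base. The helper name below is illustrative.

```python
import torch

# Hedged sketch: backward of pow(scalar_base, tensor_exponent).
# For y = base ** x, dy/dx = y * ln(base).
def pow_scalar_base_backward(grad_output, result, base):
    # result = base ** x, already computed in the forward pass;
    # base is a plain Python float here
    return grad_output * result * torch.log(torch.tensor(base))

x = torch.tensor([1.0, 2.0], requires_grad=True)
y = torch.pow(3.0, x)
y.backward(torch.ones_like(x))
manual = pow_scalar_base_backward(torch.ones_like(x), y.detach(), 3.0)
assert torch.allclose(x.grad, manual)
```
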

@facebook-github-bot added the "oncall: jit" (Add this issue/PR to JIT oncall triage queue) label Apr 16, 2019
@t-vi (Collaborator, Author) commented Apr 16, 2019

So one of the problems seems to come from the Scalar<->float magic in symbolic_script, which breaks when the Scalar is an int. Unfortunately, the only thing I could think of is a gross hack: adding torch._float to make a float from int/float IValues passed in as a Scalar.

              exponent: float):
        def backward(grad_output):
            grad_self = torch.where(torch.tensor(exponent == 0.0), torch.zeros_like(self), grad_output * exponent * torch.pow(self, exponent - 1))
            if torch._float(exponent) == 0.0:
A Contributor commented:
What's preventing float(exponent) instead?

@t-vi (Collaborator, Author) replied Apr 17, 2019

So what I think is going on: when creating the pow_0 operator for (Tensor self, Scalar exponent), symbolic_script replaces Scalar by float. The float(exponent) cast gets eliminated because the JIT "knows" it is a float. Unpacking then fails for Scalar IValues that the JIT believes must be float but that are, in fact, int.

In a way it comes down to

    // 2. to make sure the input of any graph node does not contain scalar type
    //    in its argument, all scalar arg should already be passed with float
    //    value since scalar/int aren't differentiable either way.

not being the complete picture because (as here) the scalar might not be the thing we want to differentiate for, but a parameter (in the mathematical sense as opposed to the variable) of the function we want to differentiate.
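The failing pattern can be reconstructed in a few lines (a hedged reconstruction, not the PR's own test): a scripted function differentiating torch.pow with an int Scalar exponent. Under the old Scalar->float replacement described above, the JIT assumed the exponent IValue was a float, so an int exponent could fail at unpacking time; on a fixed build this runs cleanly.

```python
import torch

# Minimal reconstruction of the failing pattern: autodiff through
# torch.pow with an int Scalar exponent inside a scripted function.
@torch.jit.script
def square(x):
    return torch.pow(x, 2)  # the exponent here is an int Scalar

x = torch.ones(3, requires_grad=True)
square(x).sum().backward()
# d/dx sum(x**2) = 2 * x, so x.grad should be all twos
assert torch.allclose(x.grad, 2 * torch.ones(3))
```
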

So I'm changing the patch to do the following:
If Scalar -> float conversion happened, I change back the input type of the graph to Scalar, and insert a conversion (prim::Float) as the first thing.
It'll be troublesome for us when we get operations that actually rely on the difference between float and int in Scalar ops, but currently we don't have any, as far as I know.

I think this is a clean-up of the Scalar->float conversion as it ensures that the graph inputs actually match the schema. @ailzhang does that seem reasonable?

@t-vi (Collaborator, Author) added:
So it turns out that re-Scalarizing breaks something (the double backward?). 😓

A Contributor commented:

cc: @wanchaol added the scalar to float conversion.

A Contributor added:
Also, the scalar to float conversion only happens on the second pass, when allow_conversions is true; if there is an op defined for Scalar, then that will be matched and there won't be a conversion.

@ailzhang (Contributor) commented Apr 17, 2019

@t-vi As @eellison pointed out offline, we can actually get rid of the scalar -> float conversion entirely.
For example this works for me.

        def pow_0(self,
                  exponent: number):

Note that we don't expose number in TorchScript, but we CAN compile it! :D With this we can easily get rid of the current symbolic_variable.h and c10::ReplaceAll(schema_str, "Scalar", "float");.
Huge thanks to @eellison who pointed this out!
Let us know if this fixes your problem :D
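The upshot of the suggestion above can be sketched as follows (a hedged illustration; function names are mine, and the `number` annotation itself lives in the internal symbolic_script strings, not in user-facing Python): once the derivative is defined over number — TorchScript's int-or-float Scalar type — instead of a float-specialized copy, scripted pow differentiates the same way for int and float exponents.

```python
import torch

# Illustrative check that int and float Scalar exponents produce
# matching gradients through scripted pow.
@torch.jit.script
def pow_int(x):
    return torch.pow(x, 3)

@torch.jit.script
def pow_float(x):
    return torch.pow(x, 3.0)

xi = torch.tensor([2.0], requires_grad=True)
xf = torch.tensor([2.0], requires_grad=True)
pow_int(xi).sum().backward()
pow_float(xf).sum().backward()
# d/dx x**3 = 3 * x**2, i.e. 12.0 at x = 2 in both cases
assert torch.allclose(xi.grad, xf.grad)
```
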

@t-vi (Collaborator, Author) replied:
Ha. That works (I think)! And I've been poking around much too long. Awesome.
Thanks so much @ailzhang and @eellison !

So my understanding is that with this, one would only have definitions that match the operator schemas. Could we actually check that? For now I left a fallback in but converted those replacements that were hit by a test_jit.py run.

@t-vi changed the title from "pow scalar exponent / base autodiff, fusion" to "[WIP] pow scalar exponent / base autodiff, fusion" Apr 17, 2019

auto sym_script_it = schema_to_graphs.find(schema_str);

if (sym_script_it == schema_to_graphs.end()) {
A Contributor commented:
Yea this if should be dropped before merging.

@t-vi (Collaborator, Author) replied:
So I removed it after double checking by means of calling sig(schema_string) and seeing whether it throws here:

auto schema_string = overloadedSchemaString(actual_schema);
schema_to_graphs[schema_string] = std::move(pair);

@ailzhang (Contributor) left a review:
LGTM! Let me know when you are done with changing and want to merge it.

@t-vi (Collaborator, Author) commented Apr 18, 2019 via email

@ailzhang changed the title from "[WIP] pow scalar exponent / base autodiff, fusion" to "pow scalar exponent / base autodiff, fusion" Apr 18, 2019
@facebook-github-bot (Contributor) left a comment:
@ailzhang has imported this pull request. If you are a Facebook employee, you can view this diff on Phabricator.

@facebook-github-bot (Contributor) commented:
@ailzhang merged this pull request in b9291f5.

zhangguanheng66 pushed a commit to zhangguanheng66/pytorch that referenced this pull request May 6, 2019
Summary:
Fixes: pytorch#19253

Fixing pow(Tensor, float) is straightforward.
The breakage for pow(float, Tensor) is a bit more subtle to trigger, and fixing needs `torch.log` (`math.log` didn't work) from the newly merged pytorch#19115  (Thanks ngimel for pointing out this has landed.)
Pull Request resolved: pytorch#19324

Differential Revision: D15003531

Pulled By: ailzhang

fbshipit-source-id: 8b22138fa27a43806b82886fb3a7b557bbb5a865

Labels: oncall: jit (Add this issue/PR to JIT oncall triage queue), open source

Successfully merging this pull request may close: torch.pow() in a script module produces an error

5 participants