Get more fusion after autodiff uses SumToSize by t-vi · Pull Request #14957 · pytorch/pytorch · GitHub

Conversation

@t-vi (Collaborator) commented Dec 9, 2018

Here is a fresh attempt at getting some fusion back in autodiff-generated graphs in the presence of SumToSize.

  • The sum-to-size operator is now aten::_grad_sum_to_size to allow symbolic script differentiation (which in turn would need to use this in place of sum_to_size to signal that it strictly operates on gradients). This is also used in the autodiff code, replacing prim::SumToSize.
  • _grad_sum_to_size is now fusable; cats - which are fused afterwards thanks to Adam's simplification of the code - are only fused if there is no _grad_sum_to_size in the fusion group.
  • I push the _grad_sum_to_size out of the fusion group when compiling and record the desired summations in the KernelSpec. The reasoning is the following (see the sketch after this list):
    • As the autodiff is a repeated application of the chain rule, we always have the pattern grad_in = mm(A, grad_out), with A often diagonal for cases interesting to the fuser, whence it is grad_in = a * grad_out (a pointwise multiplication). We know that only grad_out may have AutodiffGradSumToSize applied, so we can commute AutodiffGradSumToSize with the mul (and div and neg are of similar origin).
    • For type_as the gradient might just be providing the type, so we simply skip SumToSize there.
    • add (which was inserted as prim::AutogradAdd) adds gradients when the forward pass used the same value in several places. This addition is non-broadcasting, so we know that the two arguments have the same sizes as inputs - which is good, because we don't have to do bookkeeping for the two parts.
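A minimal sketch of the commutation argument above in plain PyTorch (not from the PR; x, grad_out and the factor a are invented for illustration). In this example a is constant along the dimension that the summation reduces over, so summing before or after the pointwise mul gives the same result, which is what allows the summation to be pushed out of the fusion group:

    import torch

    # grad_out has the broadcast shape; x had the smaller shape in the forward pass.
    x = torch.randn(3, 1)
    grad_out = torch.randn(3, 4)
    a = torch.randn(3, 1)  # "diagonal" chain-rule factor, broadcast like x

    # sum_to_size semantics: reduce the gradient back to the input's size
    # by summing over the broadcast dimensions (here dim 1).
    inside = grad_out.sum_to_size(x.shape) * a     # summation applied before the mul
    outside = (grad_out * a).sum_to_size(x.shape)  # summation pushed past the mul
    print(torch.allclose(inside, outside))         # True: the two forms agree here

If a varied along the summed dimension, the two forms would generally differ; this only illustrates a case where the commutation clearly holds.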

Details:

  • During fusion, the Tensor arguments are always kept as the first parameters of the fusion group to accommodate indexing assumptions in the fuser.
  • The rewriting of the fusion group to record the necessary output transformation and eliminate _grad_sum_to_size from the fusion group is now in the fuser compile step.
  • In the execution step, the arguments are split into Tensor / non-Tensor and the non-tensor args are mostly forgotten about, except for doing sum_to_size at the end (see the sketch after this list). This should be improved if/when we fuse non-constant scalar arguments.
  • In a number of places in the fuser, the non-Tensor arguments to the fusion group needed to be ignored.
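A minimal Python sketch of the execution-step split described in the last two bullets (the real logic lives in the C++ fuser; run_fusion_group, output_sum_specs and run_fused_kernel are hypothetical names used only for illustration):

    import torch

    def run_fusion_group(args, output_sum_specs, run_fused_kernel):
        # Tensor arguments come first, matching the fuser's indexing assumptions;
        # the trailing non-tensor arguments are only the recorded target sizes.
        tensor_args = [a for a in args if isinstance(a, torch.Tensor)]
        other_args = [a for a in args if not isinstance(a, torch.Tensor)]

        outputs = run_fused_kernel(tensor_args)

        # Apply the _grad_sum_to_size that was squeezed out of the kernel to the
        # corresponding outputs, after the fused kernel has run.
        results = []
        for out, spec in zip(outputs, output_sum_specs):
            if spec is None:
                results.append(out)  # no summation was recorded for this output
            else:
                results.append(out.sum_to_size(other_args[spec]))
        return results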

Thank you, @apaszke, for the insightful discussion. All bad ideas and errors are my own.

@facebook-github-bot added the oncall: jit label Dec 9, 2018
@apaszke (Contributor) left a comment

I think the removal of SumToSize nodes is happening way too early. Basically, you shouldn’t think of FusionGroups as graphs that have already been fused and will conform to those semantics, but as graphs that are eligible to be fused. That means that we still want to preserve the original semantics of the code, because it might turn out that our fusion guesses were wrong, and we will end up running a deoptimized version of the original code. Instead, we should allow putting them in FusionGroups and simply remove them right before a kernel is compiled (once we know that the fusion is valid, etc.).

Finally, marking those nodes as fusible is a bad idea, because the only reason to put them in a fusion group is that you are certain it will help you perform more fusions. That should be checked and processed similarly to how we deal with rearranging chunk nodes.

@t-vi (Collaborator, Author) commented Dec 10, 2018

Thanks for your comments Adam!

  • I'll rename the prim::GradSumToSize.
  • So I'll move the graph rewriting into the fuser codegen.
  • For "when to fuse SumToSize", would it be OK to put them into the fusion group if there aren't any FusedConcat nodes in there?
    This would mean that we might end up with SumToSize at the top of the fusion group which we would undo before the fusion, but I'm a bit weary that GraphFuser.run will get considerably more complicated if we split out the scan phase as done for chunk.

@t-vi (Collaborator, Author) commented Dec 13, 2018

Hmh. I need to rebase.
So I think I'm not fusing sumtosize any more when concat is close.
I'm not as sure about the "when to relocate sumtosize": If I move that into kernel generation, is it still safe to move the sumtosize to outside the fusion group? I'll try that next, but I'm still a bit sceptical about it.

apaszke and others added 19 commits December 30, 2018 18:19
We don't support reductions yet, but simply decomposing batch_norm
into a kernel that computes the stats, and then fusing everything else
with ReLU and the following pointwise ops, provides nice speedups.

Note that this is only limited to inference mode for now, because we
don't support convolutions and batch norm in AD, so the fuser isn't
applied to those parts.
That makes the definition of a "fusable node" much simpler,
as we don't need to keep considering whether something has to be an
"exit node" at every step. The fuser now tries to maximize the
pointwise fusions first, and proceeds to prepending chunks and appending
concats only once a fixed point is reached.

This patch not only makes the fuser much simpler to reason about,
it also makes it significantly easier to implement features like
SumToSize fusion, which improves the performance of derivative graphs.
@t-vi (Collaborator, Author) commented Jan 17, 2019

Thanks, @ngimel, for raising this. I'll see how to fix that. My understanding is that we would need to deduplicate it for the kernel, but not for the _grad_sum_to_size application after running the fused kernel. That in turn means we have different outputs for the fused kernel vs. the fusion group.

t-vi added 3 commits January 19, 2019 11:02
After squeezing out the _grad_sum_to_size during kernel compilation,
we may end up with duplicate outputs.
For example, the backward of
    def fn1(x,y,z):
        a = x+y+z
        return torch.sigmoid(a)
has that.
Thank you @ngimel for noting and providing the example!
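To make the duplication concrete, here is an illustration in plain PyTorch (shapes invented for the example): with broadcasting inputs, the gradients for x, y and z are all _grad_sum_to_size applied to the same intermediate gradient, so once the summations are squeezed out of the fused kernel, the kernel would otherwise return the same tensor several times.

    import torch

    def fn1(x, y, z):
        a = x + y + z
        return torch.sigmoid(a)

    x = torch.randn(4, 1, requires_grad=True)
    y = torch.randn(1, 4, requires_grad=True)
    z = torch.randn(4, 4, requires_grad=True)
    out = fn1(x, y, z)

    # The gradient flowing into a is the same tensor for all three inputs;
    # only the trailing sum_to_size differs per input.
    grad_a = out * (1 - out)  # d sigmoid(a)/da, with an all-ones incoming gradient
    gx, gy, gz = torch.autograd.grad(out, (x, y, z), torch.ones_like(out))
    print(torch.allclose(gx, grad_a.sum_to_size(x.shape)))  # True
    print(torch.allclose(gy, grad_a.sum_to_size(y.shape)))  # True
    print(torch.allclose(gz, grad_a))                       # True: z was not broadcast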
@t-vi (Collaborator, Author) commented Jan 19, 2019

So I added the output deduplication in the fuser and a test using @ngimel's example (thanks again!).

@apaszke (Contributor) left a comment

I think this should still be improved before we land. First, it can fail when one encounters some non-basic fused nodes, and there's no fallback in that case. Also, it would be good to improve the quality of fusability checks for aten::_grad_sum_to_size to avoid fusing them unnecessarily.

      return false;
    - return node->kind() == prim::FusionGroup || isSimpleMap(node);
    + return node->kind() == prim::FusionGroup ||
    +     node->kind() == aten::_grad_sum_to_size || isSimpleMap(node);
Contributor:

I'd really like us to avoid putting _grad_sum_to_size nodes in fusion groups unnecessarily. We should either have stronger checks for them (that adding them would in fact help us fuse more), or we should have a postprocessing pass that e.g. will move them out of the group if they are applied to inputs, or create outputs.

@t-vi (Collaborator, Author) commented Jan 26, 2019

Thanks for going through this!

So postprocessing or checking if the input of _grad_sum_to_size is fusible? Do you have a preference? (Edit: Would just checking isFusable(node->inputs()[0]->node()) be enough?)
Would the "create outputs" comment be mitigated by the deduplication in the fuser itself?
Personally, I would envision that one would move cases where the sumtosizes of the outputs are "ascending" (i.e. you can sort the dimensions in a way that every tensor only has the summations last) into the kernel itself.

@t-vi (Collaborator, Author):

So I now do this. This probably means we do not fuse some cases where we would like to - e.g. the milstm backward could be such a case - but we are doing better than before, and maybe we are just being conservative.

@t-vi (Collaborator, Author) commented Jan 27, 2019

@apaszke I think I acted on your comments (though I have the feeling that there might be a nicer way to cover the two uses of trackSingleGradSumToSizeToOutputs, but I don't really have an idea how).

@apaszke (Contributor) left a comment

Looks great! Some minor comments, but should be good to land

      return at::nullopt;
    }
    if (producer->node()->kind() == aten::_grad_sum_to_size &&
        consumer->kind() == prim::FusionGroup) {
Contributor:

Can we assert that consumer is a fusion group? If I understand correctly this is the only case possible by design, and having just a check like this means that the case would simply get skipped silently if that assumption ever stopped holding.

@t-vi (Collaborator, Author) commented Jan 30, 2019

I think the case where the consumer isn't a fusion group (i.e. a new fusion group will be started with a _grad_sum_to_size for the output) is legitimate; in that case we would just have the empty set to check.
The main alternative I see would be to refuse to fuse here, but that would only result in us having two _grad_sum_to_sizes in a row during execution (one pushed from inside the fusion group and one that we didn't fuse because it was at the output).

@t-vi (Collaborator, Author) commented Jan 30, 2019

So the three CI failures appear to be unrelated to my changes (flake8 in onnx/symbolic.py, which I don't think I touch; Python too old to download mypy; UtilsNMSTest.GPUEqualsCPUCorrectnessTest, which is perhaps the one I know the least about).

@facebook-github-bot (Contributor) left a comment

@zou3519 has imported this pull request. If you are a Facebook employee, you can view this diff on Phabricator.

zdevito pushed a commit to zdevito/ATen that referenced this pull request Jan 31, 2019
Pull Request resolved: pytorch/pytorch#14957

Differential Revision: D13888173

Pulled By: zou3519

fbshipit-source-id: 071992c876e8b845f2b3e6329ae03a835d39a0ea
@apaszke (Contributor) commented Feb 1, 2019

This is exciting! I'll have to rerun my benchmarks now!
