- 
                Notifications
    You must be signed in to change notification settings 
- Fork 25.7k
Get more fusion after autodiff uses SumToSize #14957
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Conversation
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
I think the removal of SumToSize nodes is happening way too early. Basically, you shouldn’t think of FusionGroups as graphs that have already been fused and will conform to those semantics, but graphs eligible to be fused. That means that we still want to preserve the original semantics of the code, because it might turn out that our fusion guesses were wrong, and will end up running a deoptimized version of the original code. Instead, we should allow putting them in FusionGroups and simply remove them right before a kernel is compiled (once we know that the fusion is valid, etc.).
Finally, marking those nodes as fusible is a bad idea, because the only reason why you might put them in a fusion group is because you are certain that it will help you perform more fusions. That should be checked and processed similarly to how we deal with rearranging chunk nodes.
| Thanks for your comments Adam! 
 | 
| Hmh. I need to rebase. | 
We don't support reductions yet, but simply decomposing batch_norm into a kernel that computes the stats, and the fusing everything else with ReLU and following pointwise ops provides nice speedups. Note that this is only limited to inference mode for now, because we don't support convolutions and batch norm in AD, so the fuser isn't applied to those parts.
That makes that definition of a "fusable node" much simpler, as we don't need to keep considering whether something has to be an "exit node" at every step. The fuser now tries to maximize the pointwise fusions first, and proceeds to prepending chunks and appending concats only once a fix point is reached. This patch not only makes the fuser much simpler to reason about, making it siginifcantly easier to implement features like SumToSize fusion, to improve performance of derivative graphs.
We don't support reductions yet, but simply decomposing batch_norm into a kernel that computes the stats, and the fusing everything else with ReLU and following pointwise ops provides nice speedups. Note that this is only limited to inference mode for now, because we don't support convolutions and batch norm in AD, so the fuser isn't applied to those parts.
That makes that definition of a "fusable node" much simpler, as we don't need to keep considering whether something has to be an "exit node" at every step. The fuser now tries to maximize the pointwise fusions first, and proceeds to prepending chunks and appending concats only once a fix point is reached. This patch not only makes the fuser much simpler to reason about, making it siginifcantly easier to implement features like SumToSize fusion, to improve performance of derivative graphs.
| Thanks @ngimel, for raising this. I'll see how to fix that. My understanding is that we would need to deduplicate it for the kernel, but not for the _grad_sum_to_size application after running the fused kernel. That in turn means we have different outputs for the fused kernel vs. the fusion group. | 
After squeezing out the _grad_sum_to_size during kernel compilation,
we may end up with duplicate outputs.
For example example, the backward of
    def fn1(x,y,z):
        a = x+y+z
        return torch.sigmoid(a)
has that.
Thank you @ngimel for noting and providing the example!
    | So I added the output deduplication in the fuser and a test using @ngimel 's example (thanks again!). | 
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
I think this should still be improved before we land. First, it can fail when one encounters some non-basic fused nodes, and there's no fallback in that case. Also, it would be good to improve the quality of fusability checks for aten::_grad_sum_to_size to avoid fusing them unnecessarily.
| return false; | ||
| return node->kind() == prim::FusionGroup || isSimpleMap(node); | ||
| return node->kind() == prim::FusionGroup || | ||
| node->kind() == aten::_grad_sum_to_size || isSimpleMap(node); | 
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
I'd really like us to avoid putting _grad_sum_to_size nodes in fusion groups unnecessarily. We should either have stronger checks for them (that adding them would in fact help us fuse more), or we should have a postprocessing pass that e.g. will move them out of the group if they are applied to inputs, or create outputs.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Thanks for going through this!
So postprocessing or checking if the input of _grad_sum_to_size is fusible? Do you have a preference? (Edit: Would just checking isFusable(node->inputs()[0]->node()) be enough?)
Would the "create outputs" comment be mitigated by the deduplication in the fuser itself?
Personally, I would envision that one would move cases where the sumtosizes of the outputs are "ascending" (i.e. you can sort the dimensions in a way that every tensor only has the summations last) into the kernel itself.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
So I now do this. This probably means we do not fuse some cases where we would like to - e.g. the milstm backward could be such a case, but we are doing better than before and maybe we just are conservative.
| @apaszke I think I acted on your comments (though I have the feeling that there might be a nicer way to cover the two uses of  | 
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Looks great! Some minor comments, but should be good to land
| return at::nullopt; | ||
| } | ||
| if (producer->node()->kind() == aten::_grad_sum_to_size && | ||
| consumer->kind() == prim::FusionGroup) { | 
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Can we assert that consumer is a fusion group? If I understand correctly this is the only case possible by design, but having a check like this would mean that this case will simply get skipped if it wasn't the case.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
I think the case where the consumer isn't a fusion group, i.e. a new fusion group will be started with a _grad_sum_to_size for the output, is legitimate, and then we would have the empty set to check.
The main alternative I see would be to refuse to fuse here, but that would only result in us having two _grad_sum_to_sizes in a row (one pushed from inside the fusion group and one that we didn't fuse because it was at the output) during execution.
| So the three CI failures would appear to be unrelated to my changes (flake8 in onnx/symbolic.py which I don't think I touch, python to old to download mypy, UtilsNMSTest.GPUEqualsCPUCorrectnessTest - this is perhaps the one I know the least about). | 
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
@zou3519 has imported this pull request. If you are a Facebook employee, you can view this diff on Phabricator.
Summary: Here is a fresh attempt at getting some fusion back in autodiff-generated graphs in the presence of SumToSize. - The sum to size operator is now `aten::_grad_sum_to_size` to allow symbolic script differentiation (and that in turn would need to use this in place of sum_to_size to signal that it strictly operates on gradients). This is also used in the autodiff code, replacing `prim::SumToSize`. - `_grad_sum_to_size` is now fusable, `cat`s - which are fused afterwards thanks to Adam's simplification of the code - are only fused if there is no `_grad_sum_to_size` in the fusion group. - I push the `_grad_sum_to_size` out of the the fusion group when compiling and record the desired summations in the KernelSpec. The reasoning is the following: - As the autodiff is a repeated applicaiton of the chain rule, we always have the pattern `grad_in = mm(A, grad_out)`, with A often diagonal for cases interesting to the fuser, whence it is `grad_in = a * grad_out` (a pointwise multiplication). We know that only `grad_out` may have AutodiffGradSumToSize applied, so we can commute AutodiffGradSumToSize with the `mul` (and `div` and `neg` are of similar origin). - For `type_as` the gradient might be giving the type, so just skip SumToSize, - `add` (which was inserted as `prim::AutogradAdd`) adding gradients when the forward used the same value in several places. This is non-broadcasting, so we know that the two arguments would have the same sizes as inputs - which is good so we don't have to do bookkeeping of the two parts. Details: - During fusion, the Tensor arguments are always kept as the first parameters of the fusion group to accomodate indexing assumptions in the fuser. - The rewriting of the fusion group to record the necessary output transformation and eliminate `_grad_sum_to_size` from the fusion group is now in the fuser compile step. - In the execution step, the arguments are split into Tensor / Non-Tensor and the non-tensor args are mostly forgotten about except for doing `sum_to_size` at the end. This would want to be improved if/when we fuse nonconstant scalar arguments. - In a number of places in the fuser, the non-Tensor arguments to the fusion group needed to be ignored. Thank you, apaszke for the insightful discussion. All bad ideas and errors are my own. Pull Request resolved: pytorch/pytorch#14957 Differential Revision: D13888173 Pulled By: zou3519 fbshipit-source-id: 071992c876e8b845f2b3e6329ae03a835d39a0ea
| This is exciting! I'll have to rerun my benchmarks now! | 
Here is a fresh attempt at getting some fusion back in autodiff-generated graphs in the presence of SumToSize.
aten::_grad_sum_to_sizeto allow symbolic script differentiation (and that in turn would need to use this in place of sum_to_size to signal that it strictly operates on gradients). This is also used in the autodiff code, replacingprim::SumToSize._grad_sum_to_sizeis now fusable,cats - which are fused afterwards thanks to Adam's simplification of the code - are only fused if there is no_grad_sum_to_sizein the fusion group._grad_sum_to_sizeout of the the fusion group when compiling and record the desired summations in the KernelSpec. The reasoning is the following:grad_in = mm(A, grad_out), with A often diagonal for cases interesting to the fuser, whence it isgrad_in = a * grad_out(a pointwise multiplication). We know that onlygrad_outmay have AutodiffGradSumToSize applied, so we can commute AutodiffGradSumToSize with themul(anddivandnegare of similar origin).type_asthe gradient might be giving the type, so just skip SumToSize,add(which was inserted asprim::AutogradAdd) adding gradients when the forward used the same value in several places. This is non-broadcasting, so we know that the two arguments would have the same sizes as inputs - which is good so we don't have to do bookkeeping of the two parts.Details:
_grad_sum_to_sizefrom the fusion group is now in the fuser compile step.sum_to_sizeat the end. This would want to be improved if/when we fuse nonconstant scalar arguments.Thank you, @apaszke for the insightful discussion. All bad ideas and errors are my own.