Fixes error when too many parameters are passed to fused cuda kernel #18063
Conversation
Fixes the error when the number of parameters passed to a fused CUDA kernel exceeds nvrtc's limit. Bug fix for issue #15043.
Add more comments for the argument limit.
We should avoid generating fusion groups this large altogether. This patch fixes the bug but silently disables an optimization in a case which is legitimate.
class TestJit(JitTestCase):
    @unittest.skipIf(not RUN_CUDA, "requires CUDA")
    def test_large_nbr_kernel_args(self):
This test will be very slow. Can we just prepare a graph that has ~200 ops that would normally get fused perfectly, but we can't because that's too much for a single kernel?
Thanks for the comment. The number of kernel arguments and the number of ops in a fusion group can be related, but not always, and the cause of this bug is the former. While we may want to limit the number of ops in a fusion group, I consider that a separate issue to address. In fact, in the original test from the issue, the offending kernel has only about 10 ops in the fusion group, but because of a FusedConcat it has a large number of live-ins, which exceeds the limit. Hence, we need to constrain the # of arguments, not the # of ops, for this specific issue.
It's true that the fix gives up fusion on the entire FusionGroup if the number of args is over the limit. One could argue for limiting the FusionGroup by # of args in GraphFuser; in fact, I did prototype that. But the trade-off isn't always straightforward. First, since this limit comes from the CUDA path, a change in GraphFuser could affect other devices. In addition, if one stops growing a FusionGroup by estimating the number of possible arguments during fusion, further fusion could actually reduce the number of live-in/out arguments in the FusionGroup back below the limit. To keep this PR simple, I didn't continue down that route. As I have learned, GraphFuser is sometimes so aggressive that we bail out to the fallback path in the Executor more often than we'd like. When we address the aggressiveness of GraphFuser (including constraining the # of ops), we will reduce the chance of giving up an entire FusionGroup by taking a preventive measure (as opposed to a reactive one).
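For readers following along, here is a minimal Python sketch of the estimate being described. The function name, the list-style arguments, and the exact limit value are illustrative assumptions; the real logic lives in the C++ GraphFuser:

```python
# Illustrative sketch only: a simplified model of "estimate the kernel argument
# count before growing a fusion group", not the actual C++ fuser code.
FUSION_KERNEL_ARGS_LIMIT = 128  # assumed bound; nvrtc imposes the real limit

def would_exceed_arg_limit(group_inputs, group_outputs,
                           producer_inputs, producer_outputs):
    # Every live-in and live-out of the merged group may become a kernel
    # parameter, so count all of them before deciding to fuse.
    estimated_args = (len(group_inputs) + len(group_outputs) +
                      len(producer_inputs) + len(producer_outputs))
    return estimated_args > FUSION_KERNEL_ARGS_LIMIT
```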
Lastly, on reducing the test case: I already cut down the test and testing time substantially from the original test. I tried again just now, but haven't made much progress. The challenge is that I need to maintain the number of arguments (> 130) while also making sure the GraphFuser fuses it. Currently, the body of each iteration is:
    b = input[i] * 2
    output.append(b)
To keep each iteration generating one new arg for a later FusedConcat, I keep the indexing form "input[i]", which is translated to 3 ops (1 select op and 2 constants: axis and unrolled index). To trigger fusion, I need to add a pointwise op (one mul). 4 ops * 130 unrolled iterations is more than a couple of hundred ops. This test does take more time than some of the really small tests, but it doesn't seem out of line, and it is skipped on the non-CUDA path. I haven't found a good way to further reduce the test, probably because I'm still new to PyTorch, but I'm open to any suggestion.
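For concreteness, a rough sketch of the shape of the reduced test being described; the class body is reconstructed from the snippets quoted in this thread, and the constructor signature and the final torch.cat are assumptions rather than the exact code in the PR:

```python
import torch
import torch.nn as nn

class Recurrence(nn.Module):
    def __init__(self, seq_len):
        super(Recurrence, self).__init__()
        self.seq_len = seq_len  # plain Python int, so the traced loop unrolls silently

    def forward(self, input):
        output = []
        for i in range(self.seq_len):    # unrolled by the tracer: ~130 iterations
            b = input[i] * 2             # select + mul; the pointwise op triggers fusion
            output.append(b)             # each iteration adds one live-in for the concat
        return torch.cat(output, dim=0)  # becomes a FusedConcat with many arguments
```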
Ok, my bad. The test does look somewhat simple, but please don't call it GRU in this case.
Also, does your test actually even trigger the fuser? You never run the traced function, so the code probably doesn't even get compiled.
Finally, regarding the tradeoffs, I think those are relatively simple. I doubt there are other devices where fusing more than 128 arguments into a single kernel would be beneficial. And emitting many kernels is always superior to emitting no kernels.
Right. It's no longer the original GRU in the issue. I had a clarifying comment, but I will rename the class to avoid confusion.
Line 607, traced_gru = torch.jit.trace(gru, (input)), calls jit trace, so compilation, including the fuser, does get invoked.
Ok. I can re-introduce the logic to estimate and track the # of arguments during fusion and stop growing the group if the limit would be exceeded. I'm at GTC for much of the week, and will get back to this later this week. Thanks.
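As a usage note, building on the Recurrence sketch above (shapes and the iteration count here are illustrative assumptions): tracing builds the graph, and executing the traced module afterwards makes sure the CUDA fusion path, including kernel compilation, is actually exercised.

```python
seq_len = 130                                  # enough iterations to exceed the arg limit
input = torch.rand(seq_len, 4, device="cuda")  # shape is illustrative only
gru = Recurrence(seq_len).cuda()
traced_gru = torch.jit.trace(gru, (input,))
out = traced_gru(input)                        # running the trace triggers fusion/compilation
```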
Revision done according to the previous comments: renamed the class in the test, and the # of kernel arguments is now estimated during fusion, bailing out if the limit may be exceeded.
ci/circleci: binary_linux_conda_2.7_cpu_build — Your tests failed on CircleCI Mar 22 01:52:03
/opt/conda/conda-bld/pytorch-nightly-cpu_1553219249856/work/third_party/ideep/mkl-dnn/src/cpu/ref_rnn.cpp:979:50: error: ‘void cblas_sgemm_free(float*)’ is deprecated (declared at /opt/conda/conda-bld/pytorch-nightly-cpu_1553219249856/_h_env_placehold_placehold_placehold_placehold_placehold_placehold_placehold_placehold_placehold_placehold_placehold_placehold_placehold_placehold_placehold_placehold_placehold_placehold_placehold_pla/include/mkl_cblas.h:804) [-Werror=deprecated-declarations]
The above test failure is not caused by this commit. It has also failed for other devs, and it appears the pytorch/pytorch repo already has this issue.
any_fused = true;
auto maybe_group = tryFuse(fused_cat, input);
if( !maybe_group ) {
  continue;
At this point any_fused is already set to true, even if no concat fusion happens, which will throw off subsequent checks.
Thanks. Moved "any_fused = true;" down, after all the early bail-outs are done.
Please adjust canFuseWithConcat instead. The point here was to make sure that this function is sufficient for tryFuse to succeed, which is checked by the assert below.
((before_check->inputs().size() + before_check->outputs().size() +
  producer->node()->inputs().size() + producer->node()->outputs().size())
 <= fusion_kernel_args_limit);
This seems unnecessary, as the same check will be performed a few lines later in tryFuse. Also, if you are trying to fuse a concat into something other than a FusionGroup, this function will return true (line 1092), which does not look right.
Remove this check and let the one in tryFuse() take care of it.
Remove the redundant check in canFuseWithConcat(); the check in tryFuse() covers it.
test/test_jit.py
def __init__(self, input_size, seq_len):
    super(Recurrence, self).__init__()
    self.input_size = input_size
    self.batch_first = True
you don't need batch_first, seq_len and input_size as module attributes for this test
Done.
# Main loop
output = []
for i in range(self.seq_len):
can use input.size(0) instead of seq_len
Using input.size(0) leads to the following warning. Since this is a test, it should be OK. Change done.
test_large_nbr_kernel_args (__main__.TestJit) ... test/test_jit.py:581: TracerWarning: Converting a tensor to a Python index might cause the trace to be incorrect. We can't record the data flow of Python values, so this value will be treated as a constant in the future. This means that the trace might not generalize to other inputs!
for i in range(input.size(0)):
Put seq_len back as a module attribute to avoid the warning.
  // but this is checked later
  return isFusable(node->inputs()[0]->node());
}
if( (node->inputs().size() + node->outputs().size()) >
This check is not needed: everywhere after isFusable, tryFuse will be called, so only the check in tryFuse needs to remain.
This check is inside isFusableMap(), which is called by isFusable() and in one other place, tryToMoveChunk(). In the latter case, there is no immediate call to tryFuse() inside the function. It looks like we should keep the guard, no?
tryToMoveChunk does not make any fusions by itself, and thus should not blacklist any nodes. Any changes to a FusionGroup are made by tryFuse, so checking (or estimating, as the case may be) the number of inputs/outputs only in tryFuse is better.
Ok. Thanks for checking. This check in isFusableMap() is removed. Just pushed a new commit. CI is starting.
if (!node->is_constant(attr::dim))
  return false;

Node* list_construct = node->namedInput(attr::tensors)->node();
list_construct is the same as tensors_node a few lines below; you can move the tensors_node line here and reuse it.
Done. Will push a commit soon after testing. Thanks.
@soumith is landing this pull request. If you are a Facebook employee, you can view this diff on Phabricator.
  continue;
}
any_fused = true;
AT_ASSERT(maybe_group && maybe_group == fused_cat);
The first part of this assert is meaningless now
Ok. The revision is done in the latest commit: a limit check is placed in canFuseWithConcat() to restore the original code sequence around the tryFuse() call in fuseConcats().
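To make the agreed-upon control flow concrete, here is a simplified Python rendering of the sequence around the tryFuse() call. It is a sketch of the C++ logic under discussion, not the actual implementation, and the function parameters are stand-ins:

```python
# Sketch: canFuseWithConcat() now carries the argument-limit estimate, so it is a
# sufficient precondition for tryFuse() to succeed, and the result can be asserted.
def fuse_concats(fused_cat, candidate_inputs, can_fuse_with_concat, try_fuse):
    any_fused = False
    for producer in candidate_inputs:
        if not can_fuse_with_concat(fused_cat, producer):  # includes the arg-limit check
            continue                                       # bail out before touching the group
        maybe_group = try_fuse(fused_cat, producer)
        any_fused = True                                   # set only once a real fusion happened
        assert maybe_group is fused_cat                    # precondition guarantees success
    return any_fused
```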
  return false;

auto tensors_node = node->namedInput(attr::tensors)->node();
if( (tensors_node->inputs().size() + node->outputs().size()) >
FWIW there's always only a single output in here, so we don't have to account for that
Since an output could still be an argument to the fused kernel generated later, I didn't change this part in the latest commit.
BTW we should assert that this never happens, because now we have guards for it. If it does, then someone messed up somewhere.
Would you suggest a specific place and assert to add?
Right here. This is the condition that should never happen in this very place
@soumith is landing this pull request. If you are a Facebook employee, you can view this diff on Phabricator.
Please just change the condition inside the fuser (not the fuser pass) to be a hard assertion.
@pytorchbot retest this please.
// Have checked the limit at graph_fuser. Assert nothing else changing that.
AT_ASSERT((flat_inputs.size() + flat_outputs.size()) <=
          fusion_kernel_args_limit);
Now that this has come full circle, and compileKernel can no longer return nullopt, can you please change the return type of this function and remove the nullopt-handling logic in the executor?
I was thinking of keeping the more flexible interface, able to handle the nullopt case, but I can remove that if desired.
The limit check in compileKernel() has been turned into an assert, and the possible nullopt return value from compileKernel() has been removed in the latest commits.
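A small Python analogue of the interface change just described, illustrative only: once compilation can no longer fail, the function asserts instead of returning an optional, and the caller's nullopt/None handling goes away. The names, types, and string "codegen" here are made up for the example:

```python
from typing import List

ARGS_LIMIT = 128  # assumed bound; the real check lives in the fuser pass

def compile_kernel(flat_inputs: List[str], flat_outputs: List[str]) -> str:
    # The fuser pass has already enforced the limit, so a violation is a logic
    # error (hard assert) rather than a reason to return None.
    assert len(flat_inputs) + len(flat_outputs) <= ARGS_LIMIT
    return "kernel({})".format(", ".join(flat_inputs + flat_outputs))  # stand-in for codegen
```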
@pytorchbot rebase this please
@soumith is landing this pull request. If you are a Facebook employee, you can view this diff on Phabricator.
Fixes error when too many parameters are passed to fused cuda kernel (pytorch#18063)

Summary: Bug fix for pytorch#15043, where a large fusion in the JIT produces a kernel whose number of arguments exceeds the limit allowed by nvrtc on a CUDA device. The fix is to check the number of arguments before a CUDA kernel is generated; if the number exceeds the limit, take the runFallBack() path. A reduced test from the original issue is added to keep the test time low; the test would fail without this fix.

Pull Request resolved: pytorch#18063
Differential Revision: D14691401
Pulled By: soumith
fbshipit-source-id: b98829bc89ed7724e91eda82ae3a5a1151af721a
Bug fix for #15043, where a large fusion in the JIT produces a kernel whose number of arguments exceeds the limit allowed by nvrtc on a CUDA device.
The fix is to check the number of arguments before a CUDA kernel is generated. If the number exceeds the limit, take the runFallBack() path.
A reduced test from the original issue is added to keep the test time low. The test would fail without this fix.
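To summarize the user-visible behavior, a hedged, illustrative check rather than the PR's actual test: a graph whose fused kernel would need more arguments than nvrtc allows should now fall back (or fuse less) instead of erroring out. This assumes the Recurrence module sketched earlier in the thread, a CUDA device, and illustrative shapes:

```python
import torch

seq_len = 200                                  # large enough to exceed the argument limit
inp = torch.rand(seq_len, 4, device="cuda")
model = Recurrence(seq_len).cuda()             # Recurrence as sketched earlier in the thread
traced = torch.jit.trace(model, (inp,))
# With the fix, this runs (taking the fallback path if needed) instead of raising
# an nvrtc error about too many kernel parameters.
torch.testing.assert_allclose(traced(inp), model(inp))
```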