Fixes error when too many parameters are passed to fused cuda kernel by royju · Pull Request #18063 · pytorch/pytorch · GitHub

Conversation

@royju
Contributor

@royju royju commented Mar 15, 2019

Bug fix for #15043, where a large fusion in the JIT produces a kernel whose number of arguments exceeds the limit allowed by nvrtc on a CUDA device.
The fix is to check the number of arguments before a cuda kernel is generated. If the number exceeds the limit, take the runFallBack() path.
Add a reduced test from the original issue to keep the test time low. The test would fail without this fix.
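For readers skimming the thread, here is a minimal sketch of the guard being described, assuming a 128-argument limit and hypothetical names (fusion_kernel_args_limit, decideFusionPath); the actual change lives in the JIT fuser and executor, and the exact limit depends on nvrtc and the device.

#include <cstddef>

// Assumed limit; nvrtc rejects fused kernels with too many parameters.
constexpr std::size_t fusion_kernel_args_limit = 128;

enum class FusionDecision { CompileFusedKernel, RunFallback };

// Count every flattened input and output the fused kernel would take; if the
// total is over the limit, take the fallback (interpreter) path instead of
// generating a CUDA kernel that nvrtc would reject.
FusionDecision decideFusionPath(std::size_t flat_inputs, std::size_t flat_outputs) {
  if (flat_inputs + flat_outputs > fusion_kernel_args_limit) {
    return FusionDecision::RunFallback;
  }
  return FusionDecision::CompileFusedKernel;
}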

@facebook-github-bot facebook-github-bot added the oncall: jit Add this issue/PR to JIT oncall triage queue label Mar 15, 2019
Contributor

@apaszke apaszke left a comment


We should avoid generating fusion groups this large altogether. This patch fixes the bug but silently disables an optimization in a case which is legitimate.


class TestJit(JitTestCase):
    @unittest.skipIf(not RUN_CUDA, "requires CUDA")
    def test_large_nbr_kernel_args(self):
Contributor

This test will be very slow. Can we just prepare a graph that has ~200 ops that would normally get fused perfectly, but we can't because that's too much for a single kernel?

Contributor Author

Thanks for the comment. The number of kernel arguments and the number of ops in a fusion group can be related, but not always, and the cause of this bug is the former. While we may want to limit the number of ops in a fusion group, I consider that a separate issue to address. In fact, in the original test from the issue, the offending kernel has only about 10 ops in the fusion group, but because of a FusedConcat it has a large number of live-ins, which exceeds the limit. Hence, we need to constrain the # of arguments, not the # of ops, for this specific issue.
It's true that the fix gives up fusion on the entire FusionGroup if the number of args is over the limit. One could argue for limiting the FusionGroup by the # of args in GraphFuser; in fact, I did prototype that. But the trade-off isn't always straightforward. First, since this limit comes from the CUDA path, a change in GraphFuser could affect other devices. In addition, if one stops growing a FusionGroup by estimating the number of possible arguments during fusion, further fusion could actually bring the number of live-in/out arguments in the FusionGroup back below the limit. To keep this PR simple, I didn't continue down that route. As I have learned, GraphFuser is sometimes so aggressive that we bail out to the fallback path in the Executor more often than we'd like. When we address the aggressiveness of GraphFuser (including constraining the # of ops), we will reduce the chance of giving up an entire FusionGroup by taking a preventive measure (as opposed to a reactive one).
Lastly, on reducing the test case: I have already largely cut down the test and testing time from the original test. I tried again just now, but haven't made much progress. The challenge is that I need to maintain the number of arguments (> 130) while also making sure the GraphFuser fuses it. Currently, the body of each iteration is
b = input[i] * 2
output.append(b)
To keep each iteration generating one new arg to a later FusedConcat, I keep the indexing form "input[i]", which is translated to 3 ops (1 select op and 2 constants: the axis and the unrolled index). To trigger fusion, I need to add a pointwise op (one mul op). 4 ops * 130 unrolled iterations adds up to more than a couple of hundred ops. This test does take more time than some of the really small tests, but it doesn't seem totally out of line, and it's skipped on the non-CUDA path. I haven't found a good way to further reduce the test, probably because I'm still new to PyTorch, but I'm open to any suggestions.

Contributor

Ok, my bad. The test does look somewhat simple, but please don't call it GRU in this case.

Also, does your test actually even trigger the fuser? You never run the traced function, so the code probably doesn't even get compiled.

Finally, regarding the trade-offs, I think those are relatively simple. I doubt there are other devices where fusing more than 128 arguments into a single kernel would be beneficial, and emitting many kernels is always superior to emitting no kernels.

Contributor Author

Right. It's no longer the original GRU in the issue. I had a clarifying comment, but I will rename the class to avoid confusion.

Line 607, traced_gru = torch.jit.trace(gru, (input)), calls jit trace, so compilation, including the fuser, does get invoked.

Ok. I can re-introduce the logic to estimate and track the # of arguments during fusion and stop growing if the limit is exceeded. I'm at GTC much of the week, and will get back to this later this week. Thanks.

@royju
Contributor Author

royju commented Mar 22, 2019

Revision done according to the previous comments: the class in the test is renamed, and the # of kernel arguments is now estimated during fusion, bailing out if the limit may be exceeded.
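A rough sketch of that estimate, with plain counts standing in for JIT node inputs/outputs; the helper name and the 128 limit are assumptions that mirror the check quoted later in this thread.

#include <cstddef>

constexpr std::size_t fusion_kernel_args_limit = 128;  // assumed value

// Conservative estimate used while growing a fusion group: if the group's
// current live-ins/outs plus the candidate producer's inputs/outputs already
// exceed the limit, stop fusing rather than failing later in nvrtc.
bool estimateFitsArgLimit(std::size_t group_inputs, std::size_t group_outputs,
                          std::size_t producer_inputs, std::size_t producer_outputs) {
  return group_inputs + group_outputs + producer_inputs + producer_outputs <=
         fusion_kernel_args_limit;
}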

@royju
Contributor Author

royju commented Mar 22, 2019

ci/circleci: binary_linux_conda_2.7_cpu_build — Your tests failed on CircleCI

Mar 22 01:52:03 /opt/conda/conda-bld/pytorch-nightly-cpu_1553219249856/work/third_party/ideep/mkl-dnn/src/cpu/ref_rnn.cpp:979:50: error: ‘void cblas_sgemm_free(float*)’ is deprecated (declared at /opt/conda/conda-bld/pytorch-nightly-cpu_1553219249856/_h_env_placehold_placehold_placehold_placehold_placehold_placehold_placehold_placehold_placehold_placehold_placehold_placehold_placehold_placehold_placehold_placehold_placehold_placehold_placehold_pla/include/mkl_cblas.h:804) [-Werror=deprecated-declarations]
Mar 22 01:52:03 cblas_sgemm_free(weights(i, j, k));
Mar 22 01:52:03 ^

The above test failure is not caused by this commit. It also failed for other devs, and it appears the pytorch/pytorch repo already has this issue.

any_fused = true;
auto maybe_group = tryFuse(fused_cat, input);
if( !maybe_group ) {
  continue;
Collaborator

at this point any_fused is already set to true, even if no concat fusion happens, which will throw off subsequent checks.

Contributor Author

Thanks. Moved "any_fused = true;" down so it is set only after all the early bail-outs are done.
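A small, self-contained illustration of the reordering described here; the Node stand-in and the injected tryFuse callback are placeholders, and the thread later moves the check into canFuseWithConcat instead.

#include <vector>

struct Node {};  // stand-in for a JIT IR node

// The control-flow point being fixed: any_fused is recorded only after
// tryFuse actually succeeds, so a failed attempt (early bail-out) no longer
// throws off the checks that follow the loop.
bool fuseConcatInputs(Node* fused_cat, const std::vector<Node*>& inputs,
                      Node* (*tryFuse)(Node* consumer, Node* producer)) {
  bool any_fused = false;
  for (Node* input : inputs) {
    Node* maybe_group = tryFuse(fused_cat, input);
    if (!maybe_group) {
      continue;  // nothing was fused for this input
    }
    any_fused = true;
  }
  return any_fused;
}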

Contributor

Please adjust canFuseWithConcat instead. The point here was to make sure that this function is sufficient for tryFuse to succeed, which is checked by the assert below.

((before_check->inputs().size() + before_check->outputs().size() +
  producer->node()->inputs().size() + producer->node()->outputs().size())
 <= fusion_kernel_args_limit);

Collaborator

This seems unnecessary, as the same check will be performed a few lines later in tryFuse. Also, if you are trying to fuse concat with something other than a FusionGroup, this function will return true (line 1092), which does not look right.

Contributor Author

Remove this check and let the one in tryFuse() take care of this.

Remove the redundant check in canFuseWithConcat(); the check in tryFuse() covers it.
test/test_jit.py Outdated
def __init__(self, input_size, seq_len):
    super(Recurrence, self).__init__()
    self.input_size = input_size
    self.batch_first = True
Collaborator

you don't need batch_first, seq_len and input_size as module attributes for this test

Contributor Author

Done.


# Main loop
output = []
for i in range(self.seq_len):
Collaborator

can use input.size(0) instead of seq_len

Contributor Author

Using input.size(0) leads to the following warning. Since this is a test, it should be ok. Change done.
test_large_nbr_kernel_args (__main__.TestJit) ... test/test_jit.py:581: TracerWarning: Converting a tensor to a Python index might cause the trace to be incorrect. We can't record the data flow of Python values, so this value will be treated as a constant in the future. This means that the trace might not generalize to other inputs!
for i in range(input.size(0)):

Contributor Author

Put seq_len back as a module attribute to avoid the warning.

  // but this is checked later
  return isFusable(node->inputs()[0]->node());
}
if( (node->inputs().size() + node->outputs().size()) >
Collaborator

This check is not needed, because tryFuse will be called everywhere after isFusable, so only the check in tryFuse needs to be kept.

Contributor Author

This check is inside isFusableMap(), which is called by isFusable() and in one other place, tryToMoveChunk(). In the latter case, there is no immediate call to tryFuse() inside the function. It looks like we should keep the guard, no?

Collaborator

tryToMoveChunk does not make any fusions by itself, and thus should not blacklist any nodes. Any changes to a FusionGroup are caused by tryFuse, so checking (or estimating, as the case may be) the number of inputs/outputs only in tryFuse is better.

Contributor Author

Ok. Thanks for checking. This check in isFusableMap() is removed. Just pushed a new commit. CI is starting.

if (!node->is_constant(attr::dim))
  return false;

Node* list_construct = node->namedInput(attr::tensors)->node();
Collaborator

list_construct is the same as tensors_node a few lines below; you can move the tensors_node line here and reuse it.

Contributor Author

Done. Will push a commit soon after testing. Thanks.

Contributor

@facebook-github-bot facebook-github-bot left a comment

@soumith is landing this pull request. If you are a Facebook employee, you can view this diff on Phabricator.

  continue;
}
any_fused = true;
AT_ASSERT(maybe_group && maybe_group == fused_cat);
Contributor

The first part of this assert is meaningless now

Contributor Author

Ok. The revision is done in the latest commit. A limit check is placed in canFuseWithConcat() to restore the original code sequence around the tryFuse() call in fuseConcats().
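Sketched with plain counts instead of real IR nodes (helper name and limit value are assumptions), the concat-side guard amounts to the following.

#include <cstddef>

constexpr std::size_t fusion_kernel_args_limit = 128;  // assumed value

// Every tensor feeding the concat's ListConstruct becomes a kernel argument,
// plus the concat output itself, so the concat-side guard can refuse the
// fusion up front when that total would exceed the limit.
bool concatWithinArgLimit(std::size_t list_construct_inputs, std::size_t concat_outputs) {
  return list_construct_inputs + concat_outputs <= fusion_kernel_args_limit;
}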

return false;

auto tensors_node = node->namedInput(attr::tensors)->node();
if( (tensors_node->inputs().size() + node->outputs().size()) >
Contributor

FWIW there's always only a single output in here, so we don't have to account for that

Contributor Author

Since an output could still be an argument to the fused kernel generated later, I didn't change this part in the latest commit.

Contributor

BTW we should assert that this never happens, because now we have guards for it. If it does, then someone messed up somewhere.

Contributor Author

Would you suggest a specific place and assert to add?

Contributor

Right here. This is the condition that should never happen in this very place


Place a limit check in canFuseWithConcat() to restore the original code sequence around tryFuse() call in fuseConcats().
Contributor

@facebook-github-bot facebook-github-bot left a comment

@soumith is landing this pull request. If you are a Facebook employee, you can view this diff on Phabricator.

Contributor

@apaszke apaszke left a comment

Please just change the condition inside the fuser (not the fuser pass) to be a hard assertion.

@royju
Contributor Author

royju commented Apr 8, 2019

@pytorchbot retest this please.

// Have checked the limit at graph_fuser. Assert nothing else changing that.
AT_ASSERT((flat_inputs.size() + flat_outputs.size()) <=
          fusion_kernel_args_limit);

Collaborator

Now that this has come full circle and compileKernel can no longer return nullopt, can you please change the return type of this function and remove the nullopt handling logic in the executor?

Contributor Author

I was thinking of keeping the more flexible interface that can handle the nullopt case, but I can remove it if so desired.

Contributor Author

The limit check in compileKernel() has been turned into an assert, and the possible nullopt return value from compileKernel() has also been removed in the latest commits.
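A hedged sketch of that change, with placeholder types standing in for the real fuser classes; the signature below is illustrative, not the actual compileKernel API.

#include <cassert>
#include <cstddef>
#include <vector>

struct FusedKernel {};  // placeholder for the real compiled-kernel type

constexpr std::size_t fusion_kernel_args_limit = 128;  // assumed value

// Before: a nullopt return signalled "too many arguments" and the executor
// had to handle it. After: graph_fuser already enforces the limit, so the
// compiler asserts the invariant and always returns a kernel.
FusedKernel compileKernel(const std::vector<float*>& flat_inputs,
                          const std::vector<float*>& flat_outputs) {
  assert(flat_inputs.size() + flat_outputs.size() <= fusion_kernel_args_limit);
  return FusedKernel{};
}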

@soumith
Member

soumith commented Apr 9, 2019

@pytorchbot rebase this please

Contributor

@facebook-github-bot facebook-github-bot left a comment

@soumith is landing this pull request. If you are a Facebook employee, you can view this diff on Phabricator.

@facebook-github-bot
Contributor

@soumith merged this pull request in a9a29dd.

zhangguanheng66 pushed a commit to zhangguanheng66/pytorch that referenced this pull request May 6, 2019
…ytorch#18063)

Summary:
Bug fix for pytorch#15043, where a large fusion in JIT with a large number of kernel arguments, which exceeds the limit allowed by nvrtc on a cuda device.
  The fix is to check the number of arguments before a cuda kernel is generated. If the number exceeds the limit, take the runFallBack() path.
  Add a reduced test from the original issue to keep the test time low. The test would fail without this fix.
Pull Request resolved: pytorch#18063

Differential Revision: D14691401

Pulled By: soumith

fbshipit-source-id: b98829bc89ed7724e91eda82ae3a5a1151af721a