-
Notifications
You must be signed in to change notification settings - Fork 25.7k
Fix bmm_sparse_cuda illegal memory access #131977
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Conversation
🔗 Helpful Links🧪 See artifacts and rendered test results at hud.pytorch.org/pr/131977
Note: Links to docs will display an error until the docs builds have been completed. ✅ No FailuresAs of commit dca5371 with merge base 0eba7e5 ( This comment was automatically generated by Dr. CI and updates every 15 minutes. |
Can we add the test case in the example as a unit test? |
I've added a unit test now @Skylion007 |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Couple of comments related to the test but overall looks like a nice simplification of the search logic here.
test/test_sparse.py
Outdated
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
This test should have an assert of some kind.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
You're right, it would be good to have an explicit failure condition (the current version fails implicitly by raising a RuntimeError on the call to torch.bmm). Maybe we can do something like this?
indices = torch.tensor([[1], [0], [0]], device=device)
values = torch.tensor([1.], device=device)
a = torch.sparse_coo_tensor(indices, values, size=(2, 1, 1))
b = torch.zeros((2, 1, 1), device=device)
try:
_ = torch.bmm(a, b)
except RuntimeError as e:
self.fail(f"torch.bmm failed: {e}")
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
It would be enough to assert on the result being correct (easy to construct the expected result by hand here). With a comment that refers to this issue for reference.
test/test_sparse.py
Outdated
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
It should be pretty simple to add this as a case for test_bmm
above
Since you need one sub matrix with no nonzero values, adding a test case with a single nnz and num_mats > 1
should be able to trigger it?
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Would also prefer not to have a separate test case. The thing that makes this a bit tricky is that the error is silent if the memory before the index tensor has been allocated since CUDA doesn't do any bounds checking (hence the call to torch.cuda.empty_cache
at the beginning of the test). So we need to be extra careful in how we set things up in order for the error to trigger reliably (can't use self._gen_sparse
). For example, just copy-pasting the current test to the end of test_bmm
results in the exception no longer being raised. This can be fixed by either having the test at the top of test_bmm
or keeping it at the end but prepending it with del a, b, ab, ab_mat, ab_mat_check, ab_traspose_check
. However, I feel both solutions could easily break in the future, so I'm not sure what the best option would be?
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Ah, I didn't realize that it was so sensitive. In that case a stand-alone test like this is warranted.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
What about a comment explaining the need for a standalone test? Seems like a non-trivial matter that takes place here. Could be useful to preserve...
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Thanks for the review @nikitaved! I've added a comment to the test
Appreciate the feedback @amjames, I've responded inline with a few follow-up comments and questions. There's another problem with how the test works that I didn't realize before: The IMA makes CUDA enter an error state that causes subsequent tests on the same CUDA device to also fail. It's currently not possible to recover from this state as far as I can tell (related discussion in #72117), so I'm not sure if triggering an IMA is the right way to go. Unfortunately, I've not found another way to test for this error - as long as the old algorithm stays within bounds, it does return the correct result. |
This PR is a fix for the IMA, so the test would only trigger that if a future modification regressed in some way. I would say that is the intended purpose of the test. The follow on failures are unavoidable, but it still does the job of signaling if some other change reintroduces the bug. I don't see any problem with that. |
Got it, that makes sense @amjames. I've added an assert statement to the test now. |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
I'm happy with this now. We will still need a review from someone on the "approved" list before merging though.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
391c7e3
to
10674ff
Compare
Thanks for the reviews! I've rebased the code @pearu |
@pytorchbot merge -r |
You don't have permissions to rebase this PR since you are a first time contributor. If you think this is a mistake, please contact PyTorch Dev Infra. |
Pull workflow has not been scheduled for the PR yet. It could be because author doesn't have permissions to run those or skip-checks keywords were added to PR/commits, aborting merge. Please get/give approval for the workflows and/or remove skip ci decorators before next merge attempt. If you think this is a mistake, please contact PyTorch Dev Infra. |
@Skylion007 @amjames Do we need to rerun workflows? Sorry for ping - not quite sure how this works |
@pytorchbot merge |
Merge startedYour change will be merged once all checks pass (ETA 0-4 Hours). Learn more about merging in the wiki. Questions? Feedback? Please reach out to the PyTorch DevX Team |
Merge failedReason: 15 mandatory check(s) failed. The first few are:
Dig deeper by viewing the failures on hud |
@pytorchbot rebase |
You don't have permissions to rebase this PR since you are a first time contributor. If you think this is a mistake, please contact PyTorch Dev Infra. |
@pytorchbot merge -f "unrelated failures" |
You are not authorized to force merges to this repository. Please use the regular |
50c80ad
to
dca5371
Compare
@pytorchbot merge |
Merge startedYour change will be merged once all checks pass (ETA 0-4 Hours). Learn more about merging in the wiki. Questions? Feedback? Please reach out to the PyTorch DevX Team |
Merge failedReason: 3 mandatory check(s) failed. The first few are: Dig deeper by viewing the failures on hud |
@pearu can you help me with approving the workflows again? 🙏 I've rebased on main, hopefully that will resolve the previous failures |
@pytorchbot merge |
Merge startedYour change will be merged once all checks pass (ETA 0-4 Hours). Learn more about merging in the wiki. Questions? Feedback? Please reach out to the PyTorch DevX Team |
This PR fixes a bug in
search_end_matrix_indices_cuda_kernel
causing an illegal memory access when callingbmm_sparse_cuda
on a sparse matrix with no non-zero values in the first batch dimension. Reproducible example:Details
In the previous code, we may for example end up with the following situation:
When
target_mat_num = 0
, the next iteration of the while loop will assign-1
toend_idx
and thus(0 + (-1)) >> 1 = -1
tomid_idx
, causing an access error on line 703. The updated code maintains the invariantstart_idx <= end_idx
and will not go out of bounds.cc @alexsamardzic @nikitaved @pearu @cpuhrsch @amjames @bhosmer @jcaip