Fix bmm_sparse_cuda illegal memory access by ludvb · Pull Request #131977 · pytorch/pytorch · GitHub

Conversation


@ludvb ludvb commented Jul 27, 2024

This PR fixes a bug in search_end_matrix_indices_cuda_kernel causing an illegal memory access when calling bmm_sparse_cuda on a sparse matrix with no non-zero values in the first batch dimension. Reproducible example:

import torch

ind = torch.tensor([[1], [0], [0]], device="cuda")
val = torch.tensor([1.], device="cuda")
A = torch.sparse_coo_tensor(ind, val, size=(2, 1, 1))
B = torch.zeros((2, 1, 1), device="cuda")
C = torch.bmm(A, B)

Details

In the previous code, we may for example end up with the following situation:

i : indices_1D[i]
------------------------------------------
0 : 1                <- start_idx, mid_idx
1 : 1                <- end_idx
...

When target_mat_num = 0, the next iteration of the while loop will assign -1 to end_idx and thus (0 + (-1)) >> 1 = -1 to mid_idx, causing an access error on line 703. The updated code maintains the invariant start_idx <= end_idx and will not go out of bounds.
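For illustration, the invariant can be sketched in host-side Python (this is a hypothetical translation for exposition, not the actual CUDA kernel code): the search keeps start_idx <= end_idx at all times, so mid_idx can never become negative, even when target_mat_num has no entries in the index array.

```python
def search_end_matrix_indices(indices_1d, target_mat_num):
    """Return the index of the first element >= target_mat_num.

    Bounds-safe binary search: start_idx <= end_idx holds on every
    iteration, so mid_idx always satisfies start_idx <= mid_idx < end_idx
    and can never go out of bounds.
    """
    start_idx, end_idx = 0, len(indices_1d)
    while start_idx < end_idx:
        # Midpoint computed without underflow.
        mid_idx = start_idx + (end_idx - start_idx) // 2
        if indices_1d[mid_idx] < target_mat_num:
            start_idx = mid_idx + 1
        else:
            end_idx = mid_idx
    return start_idx
```

With indices_1d = [1, 1] and target_mat_num = 0 (the situation above), this returns 0 instead of ever computing a negative mid_idx.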

cc @alexsamardzic @nikitaved @pearu @cpuhrsch @amjames @bhosmer @jcaip


pytorch-bot bot commented Jul 27, 2024

🔗 Helpful Links

🧪 See artifacts and rendered test results at hud.pytorch.org/pr/131977

Note: Links to docs will display an error until the docs builds have been completed.

✅ No Failures

As of commit dca5371 with merge base 0eba7e5:
💚 Looks good so far! There are no failures yet. 💚

This comment was automatically generated by Dr. CI and updates every 15 minutes.


linux-foundation-easycla bot commented Jul 27, 2024

CLA Signed

The committers listed above are authorized under a signed CLA.

@pytorch-bot pytorch-bot bot added the release notes: sparse release notes category label Jul 27, 2024
@Skylion007
Collaborator

Can we add the test case in the example as a unit test?

@jbschlosser jbschlosser added the module: sparse Related to torch.sparse label Jul 30, 2024
@jbschlosser jbschlosser requested review from jcaip, nikitaved and pearu July 30, 2024 15:39
@jbschlosser jbschlosser added the triaged This issue has been looked at by a team member, and triaged and prioritized into an appropriate module label Jul 30, 2024

ludvb commented Aug 2, 2024

I've added a unit test now @Skylion007


@amjames amjames left a comment


A couple of comments related to the test, but overall this looks like a nice simplification of the search logic here.

Collaborator


This test should have an assert of some kind.

Contributor Author


You're right, it would be good to have an explicit failure condition (the current version fails implicitly by raising a RuntimeError on the call to torch.bmm). Maybe we can do something like this?

indices = torch.tensor([[1], [0], [0]], device=device)
values = torch.tensor([1.], device=device)
a = torch.sparse_coo_tensor(indices, values, size=(2, 1, 1))
b = torch.zeros((2, 1, 1), device=device)
try:
    _ = torch.bmm(a, b)
except RuntimeError as e:
    self.fail(f"torch.bmm failed: {e}")

Collaborator


It would be enough to assert on the result being correct (easy to construct the expected result by hand here). With a comment that refers to this issue for reference.
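For example, a hedged sketch of such an assertion (the helper name is hypothetical, and the device is a parameter so the same check can be read for CPU or CUDA): since b is all zeros, the expected result is easy to construct by hand as a zero tensor of the same shape.

```python
import torch

def check_bmm_empty_first_batch(device):
    # Sparse matrix with no non-zero values in the first batch dimension.
    ind = torch.tensor([[1], [0], [0]], device=device)
    val = torch.tensor([1.0], device=device)
    a = torch.sparse_coo_tensor(ind, val, size=(2, 1, 1))
    b = torch.zeros((2, 1, 1), device=device)
    c = torch.bmm(a, b)
    # bmm(sparse, dense) typically returns a dense tensor; densify defensively.
    result = c.to_dense() if c.is_sparse else c
    # b is all zeros, so the product must be all zeros.
    torch.testing.assert_close(result, torch.zeros((2, 1, 1), device=device))
```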

Collaborator


It should be pretty simple to add this as a case for test_bmm above

Since you need one sub matrix with no nonzero values, adding a test case with a single nnz and num_mats > 1 should be able to trigger it?

Contributor Author


Would also prefer not to have a separate test case. The thing that makes this a bit tricky is that the error is silent if the memory before the index tensor has been allocated, since CUDA doesn't do any bounds checking (hence the call to torch.cuda.empty_cache at the beginning of the test). So we need to be extra careful in how we set things up in order for the error to trigger reliably (we can't use self._gen_sparse).

For example, just copy-pasting the current test to the end of test_bmm results in the exception no longer being raised. This can be fixed by either having the test at the top of test_bmm, or keeping it at the end but prepending it with del a, b, ab, ab_mat, ab_mat_check, ab_traspose_check. However, I feel both solutions could easily break in the future, so I'm not sure what the best option would be?

Collaborator


Ah, I didn't realize that it was so sensitive. In that case a stand-alone test like this is warranted.


@nikitaved nikitaved Aug 31, 2024


What about a comment explaining the need for a standalone test? Seems like a non-trivial matter that takes place here. Could be useful to preserve...

Contributor Author


Thanks for the review @nikitaved! I've added a comment to the test


ludvb commented Aug 7, 2024

Appreciate the feedback @amjames, I've responded inline with a few follow-up comments and questions.

There's another problem with how the test works that I didn't realize before: The IMA makes CUDA enter an error state that causes subsequent tests on the same CUDA device to also fail. It's currently not possible to recover from this state as far as I can tell (related discussion in #72117), so I'm not sure if triggering an IMA is the right way to go. Unfortunately, I've not found another way to test for this error - as long as the old algorithm stays within bounds, it does return the correct result.


amjames commented Aug 8, 2024

> There's another problem with how the test works that I didn't realize before: The IMA makes CUDA enter an error state that causes subsequent tests on the same CUDA device to also fail. It's currently not possible to recover from this state as far as I can tell (related discussion in #72117), so I'm not sure if triggering an IMA is the right way to go. Unfortunately, I've not found another way to test for this error - as long as the old algorithm stays within bounds, it does return the correct result.

This PR is a fix for the IMA, so the test would only trigger that if a future modification regressed in some way. I would say that is the intended purpose of the test. The follow on failures are unavoidable, but it still does the job of signaling if some other change reintroduces the bug. I don't see any problem with that.


ludvb commented Aug 9, 2024

Got it, that makes sense @amjames. I've added an assert statement to the test now.


@amjames amjames left a comment


I'm happy with this now. We will still need a review from someone on the "approved" list before merging though.

@amjames amjames requested a review from cpuhrsch August 9, 2024 14:35

@pearu pearu left a comment


I confirm that the new search_end_matrix_indices algorithm is correct. Thanks, @ludvb!

@ludvb could you rebase the PR against the main branch?

@ludvb ludvb force-pushed the end_matrix_indices branch from 391c7e3 to 10674ff Compare August 30, 2024 19:57

ludvb commented Aug 30, 2024

Thanks for the reviews! I've rebased the code @pearu


ludvb commented Sep 26, 2024

@pytorchbot merge -r


pytorch-bot bot commented Sep 26, 2024

You don't have permissions to rebase this PR since you are a first time contributor. If you think this is a mistake, please contact PyTorch Dev Infra.


pytorch-bot bot commented Sep 26, 2024

Pull workflow has not been scheduled for the PR yet. It could be because author doesn't have permissions to run those or skip-checks keywords were added to PR/commits, aborting merge. Please get/give approval for the workflows and/or remove skip ci decorators before next merge attempt. If you think this is a mistake, please contact PyTorch Dev Infra.


ludvb commented Sep 26, 2024

@Skylion007 @amjames Do we need to rerun the workflows? Sorry for the ping - I'm not quite sure how this works

@pearu pearu added the ciflow/trunk Trigger trunk jobs on your pull request label Sep 26, 2024

ludvb commented Oct 6, 2024

@pytorchbot merge

@pytorchmergebot

Merge started

Your change will be merged once all checks pass (ETA 0-4 Hours).

Learn more about merging in the wiki.

Questions? Feedback? Please reach out to the PyTorch DevX Team

Advanced Debugging: check the merge workflow status here.


ludvb commented Oct 6, 2024

@pytorchbot rebase


pytorch-bot bot commented Oct 6, 2024

You don't have permissions to rebase this PR since you are a first time contributor. If you think this is a mistake, please contact PyTorch Dev Infra.


ludvb commented Oct 6, 2024

@pytorchbot merge -f "unrelated failures"


pytorch-bot bot commented Oct 6, 2024

You are not authorized to force merges to this repository. Please use the regular @pytorchmergebot merge command instead

@ludvb ludvb force-pushed the end_matrix_indices branch from 50c80ad to dca5371 Compare October 6, 2024 14:58

ludvb commented Oct 7, 2024

@pytorchbot merge

@pytorchmergebot

Merge started

Your change will be merged once all checks pass (ETA 0-4 Hours).

Learn more about merging in the wiki.

Questions? Feedback? Please reach out to the PyTorch DevX Team

Advanced Debugging: check the merge workflow status here.

@pytorchmergebot

Merge failed

Reason: 3 mandatory check(s) failed. The first few are:

Dig deeper by viewing the failures on hud

Details for Dev Infra team (raised by workflow job)

Failing merge rule: Core Maintainers


ludvb commented Oct 7, 2024

@pearu can you help me with approving the workflows again? 🙏 I've rebased on main, hopefully that will resolve the previous failures


ludvb commented Oct 7, 2024

@pytorchbot merge

@pytorchmergebot

Merge started

Your change will be merged once all checks pass (ETA 0-4 Hours).

Learn more about merging in the wiki.

Questions? Feedback? Please reach out to the PyTorch DevX Team

Advanced Debugging: check the merge workflow status here.


Labels

- ciflow/trunk (Trigger trunk jobs on your pull request)
- Merged
- module: sparse (Related to torch.sparse)
- open source
- release notes: sparse (release notes category)
- triaged (This issue has been looked at by a team member, and triaged and prioritized into an appropriate module)
