-
Notifications
You must be signed in to change notification settings - Fork 25.7k
[resubmit] masked_scatter thrust->cub #58865
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Conversation
💊 CI failures summary and remediationsAs of commit 53283de (more details on the Dr. CI page):
This comment was automatically generated by Dr. CI (expand for details).Follow this link to opt-out of these comments for your Pull Requests.Please report bugs/suggestions to the (internal) Dr. CI Users group. |
| // The number of `1` elements present in the mask must be <= the | ||
| // number of elements available in `src` | ||
| TORCH_CHECK(totalElements <= srcSize, "source nElements must be == mask `1` elements"); | ||
| if (srcSize == 0) { |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Here you are returning without an error when possibly source was invalid (contained too few elements), but returning if mask.numel() == 0 should be fine?
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
I think if srcSize == 0, then mask after broadcasting should also have 0 elements as well. But if srcSize is not 0, then mask.numel() can not be 0 as well, because it can not be broadcastable.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
It should, but if users sent empty src (erroneously), then instead of getting en error (like today), they'd get not updated self. Here's today's behavior, and there won't be an error after this change you are making:
In [2]: a=torch.randn(4,4)
In [3]: mask = a>0
In [4]: vals = torch.randn(0)
In [5]: vals.size()
Out[5]: torch.Size([0])
In [6]: torch.masked_scatter(a, mask, vals)
---------------------------------------------------------------------------
RuntimeError Traceback (most recent call last)
<ipython-input-6-35daff38b05a> in <module>
----> 1 torch.masked_scatter(a, mask, vals)
RuntimeError: Number of elements of source < number of ones in mask
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Oh, I see, I thought I was checking self.numel() == 0. Thanks for figuring out.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
fixed now. I changed it to self.numel() == 0
|
@ngimel has imported this pull request. If you are a Facebook employee, you can view this diff on Phabricator. |
Codecov Report
@@ Coverage Diff @@
## master #58865 +/- ##
==========================================
- Coverage 76.52% 76.51% -0.01%
==========================================
Files 2008 2009 +1
Lines 201113 201827 +714
==========================================
+ Hits 153899 154436 +537
- Misses 47214 47391 +177 |
Summary: See ae7760c for the fix Pull Request resolved: pytorch#58865 Reviewed By: mruberry Differential Revision: D28657940 Pulled By: ngimel fbshipit-source-id: 9155c710b0e18ebb3bfa2dabfdd117355ac30840
See ae7760c for the fix