KEMBAR78
[resubmit] masked_scatter thrust->cub by zasdfgbnm · Pull Request #58865 · pytorch/pytorch · GitHub
Skip to content

Conversation

@zasdfgbnm
Copy link
Collaborator

@zasdfgbnm zasdfgbnm commented May 24, 2021

See ae7760c for the fix

@facebook-github-bot
Copy link
Contributor

facebook-github-bot commented May 24, 2021

💊 CI failures summary and remediations

As of commit 53283de (more details on the Dr. CI page):


  • 1/1 failures possibly* introduced in this PR
    • 1/1 non-scanned failure(s)

This comment was automatically generated by Dr. CI (expand for details).Follow this link to opt-out of these comments for your Pull Requests.

Please report bugs/suggestions to the (internal) Dr. CI Users group.

Click here to manually regenerate this comment.

// The number of `1` elements present in the mask must be <= the
// number of elements available in `src`
TORCH_CHECK(totalElements <= srcSize, "source nElements must be == mask `1` elements");
if (srcSize == 0) {
Copy link
Collaborator

@ngimel ngimel May 24, 2021

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Here you are returning without an error when possibly source was invalid (contained too few elements), but returning if mask.numel() == 0 should be fine?

Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I think if srcSize == 0, then mask after broadcasting should also have 0 elements as well. But if srcSize is not 0, then mask.numel() can not be 0 as well, because it can not be broadcastable.

Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

It should, but if users sent empty src (erroneously), then instead of getting en error (like today), they'd get not updated self. Here's today's behavior, and there won't be an error after this change you are making:

In [2]: a=torch.randn(4,4)

In [3]: mask = a>0

In [4]: vals = torch.randn(0)

In [5]: vals.size()
Out[5]: torch.Size([0])

In [6]: torch.masked_scatter(a, mask, vals)
---------------------------------------------------------------------------
RuntimeError                              Traceback (most recent call last)
<ipython-input-6-35daff38b05a> in <module>
----> 1 torch.masked_scatter(a, mask, vals)

RuntimeError: Number of elements of source < number of ones in mask

Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Oh, I see, I thought I was checking self.numel() == 0. Thanks for figuring out.

Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

fixed now. I changed it to self.numel() == 0

@facebook-github-bot
Copy link
Contributor

@ngimel has imported this pull request. If you are a Facebook employee, you can view this diff on Phabricator.

@codecov
Copy link

codecov bot commented May 25, 2021

Codecov Report

Merging #58865 (53283de) into master (007fe94) will decrease coverage by 0.00%.
The diff coverage is n/a.

@@            Coverage Diff             @@
##           master   #58865      +/-   ##
==========================================
- Coverage   76.52%   76.51%   -0.01%     
==========================================
  Files        2008     2009       +1     
  Lines      201113   201827     +714     
==========================================
+ Hits       153899   154436     +537     
- Misses      47214    47391     +177     

@facebook-github-bot
Copy link
Contributor

@ngimel merged this pull request in c883334.

@zasdfgbnm zasdfgbnm deleted the masked_scatter_revert branch May 25, 2021 21:59
deniskokarev pushed a commit to deniskokarev/pytorch that referenced this pull request Jun 9, 2021
Summary:
See ae7760c for the fix

Pull Request resolved: pytorch#58865

Reviewed By: mruberry

Differential Revision: D28657940

Pulled By: ngimel

fbshipit-source-id: 9155c710b0e18ebb3bfa2dabfdd117355ac30840
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment

Projects

None yet

Development

Successfully merging this pull request may close these issues.

4 participants