OpInfo: where #58349
Conversation
💊 CI failures summary and remediations (as of commit ecbfd82; more details on the Dr. CI page):
🚧 2 fixed upstream failures: these were probably caused by upstream breakages that were already fixed. Please rebase on the
Too verbose? Or easy to read?
It is easy to read, but at the point where this is written, why not just do

```python
yield SampleInput(make_arg(M, M), args=(make_bool_mask(M, M), make_arg(M, M)))
```

instead of each case?
```python
cases = (((M, M), (M, M), (M, M), False),
         ((M, 1, M), (M, M), (M, M, 1), True),
         ((), (), (), False),
         ((M, 1, M), (), (M, M, 1), True),
         ((), (M, M), (), True),)

def generator():
    for shape, mask_shape, other_shape, broadcasts_input in cases:
        yield SampleInput(make_arg(shape),
                          args=(make_bool_mask(mask_shape), make_arg(other_shape)),
                          broadcasts_input=broadcasts_input)
```

Is it okay to stick to the usual cases-as-a-tuple approach? Even with yield, writing out every sample input is very verbose, and it would be hard to extend the cases if the need (ever) arises.
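For context, the broadcasts_input flag marks samples whose input must broadcast against the other arguments (as I understand it, the test suite uses it, among other things, when checking in-place variants, which cannot broadcast their input). A quick standalone illustration of the broadcasting torch.where performs, mirroring the (M, 1, M) / (M, M) / (M, M, 1) case above with M = 5 (shapes here are just for demonstration):

```python
import torch

cond = torch.ones(5, 5, dtype=torch.bool)   # mask shape  (M, M)
x = torch.randn(5, 1, 5)                    # input shape (M, 1, M)
other = torch.randn(5, 5, 1)                # other shape (M, M, 1)

# All three arguments broadcast together to a common shape.
out = torch.where(cond, x, other)
print(out.shape)  # torch.Size([5, 5, 5])
```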
Codecov Report

```diff
@@            Coverage Diff             @@
##           master   #58349      +/-   ##
==========================================
- Coverage   76.47%   76.45%   -0.02%
==========================================
  Files        1992     1992
  Lines      199914   199922       +8
==========================================
- Hits       152875   152849      -26
- Misses      47039    47073      +34
```
This is going to need a comment; this is because we don't support the input being anything but the first argument, right?
Yes. For this specific case we could pass the input as the condition, but since condition is always bool, we won't gradcheck on it.
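For context on why the bool condition is fine to leave unchecked: gradcheck only perturbs floating-point (and complex) tensors, so a bool mask is treated as a constant. A minimal standalone illustration (shapes and values arbitrary):

```python
import torch

cond = torch.tensor([True, False, True])
x = torch.randn(3, dtype=torch.double, requires_grad=True)
y = torch.randn(3, dtype=torch.double, requires_grad=True)

# gradcheck numerically perturbs only x and y; the bool mask is a
# constant from its point of view.
assert torch.autograd.gradcheck(lambda a, b: torch.where(cond, a, b), (x, y))
```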
Can we use gradcheck_wrapper to handle the gradcheck case? Or are there other tests which would still be an issue?
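For reference, OpInfos can specify a gradcheck_wrapper, a callable wrapped around the op before gradcheck runs (gradcheck_wrapper_hermitian_input is an existing example). A hypothetical wrapper for where might reorder the arguments so the differentiable tensor sits in the first slot; this is a sketch only, assuming the (op, input, *args) convention, with all names illustrative:

```python
def gradcheck_wrapper_where(op, self, condition, other, **kwargs):
    # Hypothetical: `self` (the differentiable tensor) arrives as the
    # sample's input; reorder so the op still sees (condition, self, other).
    return op(condition, self, other, **kwargs)
```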
It also affects test_variant_consistency_eager:

Lines 303 to 307 in 28840b9

```python
# TODO: backward consistency only checked on sample.input, not all
#   tensor inputs
# TODO: update to handle checking grads of all tensor inputs as
#   derived from each tensor output
if (op.supports_autograd and isinstance(expected_forward, torch.Tensor)
```
OK
The skip will need a comment, too.
Looks good as usual @kshitij12345, I made a few comments inline for your review
(Branch force-pushed from 1853b2c to a09f08c.)
@mruberry gentle ping :)
```python
def sample_inputs_where(op_info, device, dtype, requires_grad, **kwargs):
    make_arg = partial(make_tensor, dtype=dtype, device=device, requires_grad=requires_grad)
    make_bool_mask = partial(make_tensor, dtype=torch.bool, device=device, requires_grad=requires_grad)
```
The function this is replacing guaranteed that at least one value in the mask is True.
This PR should preserve that behavior. It can do so by creating its own make_bool_mask function that checks whether the tensor contains at least one True value and, if not, inserts a True at a random point, for example.
Is it okay to use mask_not_all_zeros? It is already used by other OpInfo sample functions.
I wonder how this works for index_put?
pytorch/torch/testing/_internal/common_methods_invocations.py
Lines 1403 to 1410 in 28840b9
```python
# Test with mask arg
mask = torch.zeros(S, dtype=torch.bool) if accumulate else mask_not_all_zeros((S,))
inputs.append(SampleInput(
    make_tensor((S, S), device, dtype, low=None, high=None, requires_grad=requires_grad),
    args=(
        (mask,),
        make_tensor((S,), device, dtype, low=None, high=None),),
    kwargs=dict(accumulate=accumulate)))
```
I believe the args should be on the same device. But for mask, neither torch.zeros nor mask_not_all_zeros specifies the device 🤔
Sure; although eventually it'd be good to improve upon it and not just redraw
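("Redraw" because mask_not_all_zeros resamples until the mask contains a True.) For readers, the helper works roughly like this (paraphrased from common_methods_invocations.py; may not be verbatim):

```python
import torch

def mask_not_all_zeros(shape):
    # Resample random bool masks until at least one element is True.
    assert len(shape) > 0
    while True:
        result = torch.randn(shape).gt(0)
        if result.sum() > 0:
            return result
```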
@mruberry take a look at the updated comment. I am confused.
I think some of the indexing ops do allow the indices (and maybe masks) to be on the CPU:

```python
import torch

t = torch.tensor((1, 2, 3), device='cuda')
indices = torch.tensor((1,))
values = torch.tensor((4,), device='cuda')
t.index_put_((indices,), values)
# tensor([1, 4, 3], device='cuda:0')
```
How does this look?

```python
def make_bool_mask(shape):
    # Make sure at least one element is nonzero,
    # except for empty tensors.
    mask_t = make_tensor(shape, dtype=torch.bool, device=device, requires_grad=False)
    if mask_t.sum() == 0:
        if mask_t.numel() == 1 and len(mask_t.shape) > 0:
            mask_t[0] = True
            return mask_t
        elif reduce(lambda x, y: x * y, mask_t.shape, 1) == 0:
            return mask_t

        def random_index(shape):
            # random.randint is inclusive on both ends, hence max_idx - 1
            return tuple(map(lambda max_idx: random.randint(0, max_idx - 1), shape))

        mask_t[random_index(mask_t.shape)] = True
        return mask_t
    return mask_t
```
The handling of tensors with zero elements looks good, but I think that check can be simplified to mask_t having numel() == 0? The handling of the general case also looks correct.
I think the tricky case, then, is if shape is () or (1,). For these I think you need two checks: one for numel() == 1 and then a second check for len(shape) == 0. In both these cases the initial construction of the tensor is irrelevant, because you already know the tensor you want to return, so I would put these checks before the make_tensor call, short-circuit on 0 numel, and then only guard the random_index case behind the call to sum():
```python
if numel == 1:
    if len(shape) == 0:
        return torch.tensor(True, ...)
    return torch.tensor((True,), ...)
mask = make_tensor(...)
if mask.numel() == 0:
    return mask
if mask.sum() == 0:
    ...  # sets a random index to True
return mask
```
Tried to incorporate the suggestion. How about:

```python
def make_bool_mask(shape):
    # Make sure at least one element is nonzero,
    # except for empty tensors.
    mask_t = make_tensor(shape, dtype=torch.bool, device=device, requires_grad=False)

    if mask_t.numel() == 0:
        return mask_t
    elif mask_t.numel() == 1:
        # These checks come after make_tensor so that `dtype`, `device`,
        # etc. only have to be specified once, to make_tensor.
        mask_t.fill_(True)
        return mask_t

    if mask_t.sum() == 0:
        def random_index(shape):
            # random.randint is inclusive on both ends, hence max_idx - 1
            return tuple(map(lambda max_idx: random.randint(0, max_idx - 1), shape))

        mask_t[random_index(mask_t.shape)] = True
        return mask_t

    return mask_t
```
LGTM!
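Putting the thread's pieces together, the sample function would look roughly like the following. This is a sketch assembled from the snippets above, not necessarily the exact merged code; M = 10 and the import paths are assumptions here:

```python
import random
from functools import partial

import torch
from torch.testing import make_tensor  # import path may differ by version
from torch.testing._internal.common_methods_invocations import SampleInput

M = 10  # illustrative; the test suite defines its own sizes

def sample_inputs_where(op_info, device, dtype, requires_grad, **kwargs):
    make_arg = partial(make_tensor, dtype=dtype, device=device,
                       requires_grad=requires_grad)

    def make_bool_mask(shape):
        # Make sure at least one element is True, except for empty tensors.
        mask_t = make_tensor(shape, dtype=torch.bool, device=device,
                             requires_grad=False)
        if mask_t.numel() == 0:
            return mask_t
        elif mask_t.numel() == 1:
            mask_t.fill_(True)
            return mask_t

        if mask_t.sum() == 0:
            def random_index(shape):
                # random.randint is inclusive on both ends, hence max_idx - 1
                return tuple(random.randint(0, max_idx - 1) for max_idx in shape)

            mask_t[random_index(mask_t.shape)] = True

        return mask_t

    cases = (((M, M), (M, M), (M, M), False),
             ((M, 1, M), (M, M), (M, M, 1), True),
             ((), (), (), False),
             ((M, 1, M), (), (M, M, 1), True),
             ((), (M, M), (), True),)

    def generator():
        for shape, mask_shape, other_shape, broadcasts_input in cases:
            yield SampleInput(make_arg(shape),
                              args=(make_bool_mask(mask_shape), make_arg(other_shape)),
                              broadcasts_input=broadcasts_input)

    return list(generator())
```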
Overall looks good, @kshitij12345; I made a few inline comments for your review
Note to self: look/fix failing NNC test.
I think we just need to add this op to the list stated in the test. @Chillee, seems like we should put this metadata into the OpInfo so people don't have to find yet another list?

@mruberry gentle ping :)
Nice work @kshitij12345! Thanks for the ping
@mruberry has imported this pull request. If you are a Facebook employee, you can view this diff on Phabricator.
Summary: Reference: pytorch#54261
Pull Request resolved: pytorch#58349
Reviewed By: mrshenli
Differential Revision: D28744220
Pulled By: mruberry
fbshipit-source-id: 893a2fb88a48a60df75c7d6e2f58a42ca949daa7
Reference: #54261