Foreach Binary Test Refactor by crcrpar · Pull Request #59907 · pytorch/pytorch · GitHub

Conversation

@crcrpar
Collaborator

@crcrpar crcrpar commented Jun 12, 2021

Related: #58833

Changes I'm a bit concerned about

cc @ptrblck @ngimel @mcarilli

@facebook-github-bot
Contributor

facebook-github-bot commented Jun 12, 2021

💊 CI failures summary and remediations

As of commit 39cd7b8 (more details on the Dr. CI page and at hud.pytorch.org/pr/59907):


💚 💚 Looks good so far! There are no failures yet. 💚 💚


Preview docs built from this PR

This comment was automatically generated by Dr. CI.

Comment on lines +2327 to +2688
if same_size:
return [make_tensor((N, N), device, dtype, noncontiguous=noncontiguous) for _ in range(N)]

@crcrpar crcrpar mentioned this pull request Jun 12, 2021
28 tasks
@ailzhang ailzhang requested review from mcarilli, ngimel and ptrblck June 14, 2021 17:27
@ailzhang ailzhang added the triaged label (This issue has been looked at by a team member, and triaged and prioritized into an appropriate module) Jun 14, 2021
Collaborator

@ngimel ngimel left a comment

First part of review.

  1. Please unify the tests for scalar dtypes and for the fast and slow paths; it's hard to get a good picture of what's tested otherwise. If you want separate tests for these, the correct thing to do is to use decorators (but here it may be more trouble than it's worth).
  2. Please fix the upcasting issues for addition and subtraction in a separate PR; we should not have special handling in the tests, and the operations have to produce the same results. It's OK if regular addition produces slightly different results than before.
  3. Please streamline the input/input2 handling; I don't see why input2 is needed.

@crcrpar
Collaborator Author

crcrpar commented Jun 15, 2021

Thank you for the review. I'm removing the input2 handling and unifying the different scalar/scalar-list tests into one, so there'll be test_binary_op_tensorlists_(fast|slow)path, test_binary_op_scalar_(fast|slow)path, and test_binary_op_scalarlist_(fast|slow)path.

@ngimel
Collaborator

ngimel commented Jun 17, 2021

Cool, thanks, do you plan on fixing regular addition so that we don't have to have requires_cast things?

@crcrpar
Collaborator Author

crcrpar commented Jun 17, 2021

Cool, thanks, do you plan on fixing regular addition so that we don't have to have requires_cast things?

Yes. I wrote crcrpar@4627e30 (using the accumulate type and having requires_cast always return False) on top of this branch, and a local run of test_foreach.py was successful.

@ngimel
Collaborator

ngimel commented Jun 17, 2021

That's cool. Can you submit the PR for add separately, so that this one can become simpler without the need to handle requires_cast?

Collaborator

@ngimel ngimel left a comment

This looks much better, thanks! @mruberry and I looked over it and left a few comments, let us know what you think.
We could also merge the fix to regular add to avoid the conversions.

self.func = func

def __call__(self, inputs, **kwargs):
if len(inputs) == 2 and isinstance(inputs[1], (int, float, complex, bool)):
Collaborator

you can test isinstance(inputs[1], Number)

Collaborator Author

didn't know that, thanks
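A minimal sketch of the suggested check, assuming the standard-library numbers.Number ABC (int, float, complex, and bool are all registered with it):

    from numbers import Number

    def is_scalar(x):
        # int, float, complex, and bool all register as numbers.Number,
        # so one isinstance check replaces the explicit tuple of builtin types.
        return isinstance(x, Number)

    assert all(is_scalar(v) for v in (3, 3.14, 1 + 2j, True))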

for N in N_values:
tensors1 = self._get_test_data(device, dtype, N)
tensors2 = self._get_test_data(device, dtype, N)
def _regular_binary_test(self, dtype, op, ref, inputs, is_fastpath, *, alpha=None):
Collaborator

Now _regular_binary_test and _inplace_binary_test are sufficiently similar that you can unify them into a single function; it can have an arg that controls whether to clone the inputs.
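A rough sketch of such a unified helper (the signature and the way op/ref are called here are illustrative, not the PR's exact code):

    def _binary_test(self, dtype, op, ref, inputs, is_fastpath, *, is_inplace=False, alpha=None):
        # For the in-place variant, hand the reference clones of the first
        # tensor list so it still sees the unmodified inputs.
        ref_inputs = [[t.clone().detach() for t in inputs[0]], inputs[1]] if is_inplace else inputs
        kwargs = {} if alpha is None else {'alpha': alpha}
        expected = ref(ref_inputs, **kwargs)
        actual = op(inputs, **kwargs)
        # In-place foreach ops write their results into the first tensor list.
        self.assertEqual(inputs[0] if is_inplace else actual, expected)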

self._regular_binary_test(dtype, op, ref, inputs, is_fastpath, alpha=alpha)
self._inplace_binary_test(dtype, inplace_op, inplace_ref, inputs, is_fastpath, alpha=alpha)

@skipMeta
Collaborator

is it failing the meta test?

Collaborator

Should this be skipped for ROCm too?

Collaborator Author

Yes, it's failing because:

  • this test compares the exception if one is raised
  • foreach ops with meta tensors raise
  • but the regular ops with meta tensors do not raise
FAILED test/test_foreach.py::TestForeachMETA::test_binary_op_tensorlists_fastpath__foreach_add_meta_float32 - AssertionError: NotImplementedError not raised
FAILED test/test_foreach.py::TestForeachMETA::test_binary_op_tensorlists_fastpath__foreach_add_meta_float64 - AssertionError: NotImplementedError not raised
FAILED test/test_foreach.py::TestForeachMETA::test_binary_op_tensorlists_fastpath__foreach_div_meta_float32 - AssertionError: NotImplementedError not raised
FAILED test/test_foreach.py::TestForeachMETA::test_binary_op_tensorlists_fastpath__foreach_div_meta_float64 - AssertionError: NotImplementedError not raised
FAILED test/test_foreach.py::TestForeachMETA::test_binary_op_tensorlists_fastpath__foreach_mul_meta_float32 - AssertionError: NotImplementedError not raised
FAILED test/test_foreach.py::TestForeachMETA::test_binary_op_tensorlists_fastpath__foreach_mul_meta_float64 - AssertionError: NotImplementedError not raised
FAILED test/test_foreach.py::TestForeachMETA::test_binary_op_tensorlists_fastpath__foreach_sub_meta_float32 - AssertionError: NotImplementedError not raised
FAILED test/test_foreach.py::TestForeachMETA::test_binary_op_tensorlists_fastpath__foreach_sub_meta_float64 - AssertionError: NotImplementedError not raised
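A hedged sketch of the exception-mirroring pattern being described (names are illustrative): the per-tensor reference and the foreach call are expected to either both succeed or both raise, and meta tensors break that symmetry because only one side raises.

    # Illustrative only: if the per-tensor reference raises, the foreach op is
    # expected to raise the same kind of error; otherwise the results are
    # compared. With meta tensors only one of the two paths raises, so an
    # assertion like this cannot pass and the test is skipped via @skipMeta.
    try:
        expected = [torch.add(a, b) for a, b in zip(tensors1, tensors2)]
    except Exception as e:
        with self.assertRaises(type(e)):
            torch._foreach_add(tensors1, tensors2)
    else:
        self.assertEqual(torch._foreach_add(tensors1, tensors2), expected)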

self._regular_binary_test(dtype, op, ref, inputs, is_fastpath)
self._inplace_binary_test(dtype, inplace_op, inplace_ref, inputs, is_fastpath)
if opinfo.supports_alpha_param:
alpha = 3 if dtype in torch.testing.get_all_int_dtypes() else 3.14
Collaborator

can you create a complex alpha for complex inputs? There is dtype.is_complex to check
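For instance, a hedged sketch of the dtype-dependent alpha choice (the int/float values come from the quoted code; the complex branch is the suggestion):

    # Match alpha's type to the tensor dtype so that alpha handling is
    # exercised for integer, floating, and complex inputs alike.
    if dtype in torch.testing.get_all_int_dtypes():
        alpha = 3
    elif dtype.is_complex:
        alpha = complex(3, 3)
    else:
        alpha = 3.14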

disable_fastpath = True
self._test_binary_op_tensorlists(device, dtype, op, N, True, disable_fastpath)

@dtypes(*torch.testing.get_all_dtypes())
Collaborator

what are the dtypes that regular ops support but foreach ops don't? Otherwise you could test the fast path and slow path for the same dtypes.

Collaborator Author

ah, now that foreach binary ops use the AT_DISPATCH_ALL_TYPES_AND_COMPLEX* dispatcher, the dtypes might be the same. I'll double check.

Collaborator Author

the supported dtypes look the same because AT_DISPATCH_ALL_TYPES_AND_COMPLEX_AND3 is used, and foreach doesn't support sparse tensors at the moment.

Collaborator

Yeah, but that decorator doesn't make you test sparse inputs, so testing just the foreach-supported dtypes (the default) should be good.
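In other words, the explicit @dtypes(*torch.testing.get_all_dtypes()) decorator can be dropped and the dtypes declared for each op in foreach_binary_op_db used instead, roughly:

    # @ops already parametrizes over the dtypes declared for each entry in
    # foreach_binary_op_db, so no separate @dtypes decorator is needed just to
    # cover the foreach-supported dtypes.
    @ops(foreach_binary_op_db)
    def test_binary_op_tensorlists_fastpath(self, device, dtype, op):
        ...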

if dtype in [torch.bfloat16, torch.bool, torch.float16]:
tensors2 = [torch.zeros(N, N, device=device, dtype=dtype).add(2) for _ in range(N)]
# different devices
if torch.cuda.is_available() and torch.cuda.device_count() > 1:
Collaborator

Please check that the device argument you got is cuda: torch.device(device).type == "cuda"
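That is, roughly:

    # Only exercise the cross-device case when this test instance is actually
    # running on CUDA and a second GPU is available.
    if torch.device(device).type == "cuda" and torch.cuda.device_count() > 1:
        ...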

res = torch._foreach_add([tensor1], [tensor2])
torch._foreach_add_([tensor1], [tensor2])
self.assertEqual(res, [tensor1])
tensor1 = torch.randn(10, 10, device=device).to(dtype)
Collaborator

use make_tensor here
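For example, assuming the same make_tensor helper already used elsewhere in this diff:

    # make_tensor replaces the randn(...).to(dtype) dance and also works for
    # integer, bool, and complex dtypes.
    tensor1 = make_tensor((10, 10), device, dtype)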

# non contiguous
tensor1 = torch.randn(5, 2, 1, 3, device=device)[:, 0]
tensor2 = torch.randn(5, 2, 1, 3, device=device)[:, 0]
tensor1 = torch.randn(5, 2, 1, 3, device=device).to(dtype)[:, 0]
Collaborator

make_tensor(noncontiguous=True)
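i.e., something along the lines of (shapes are illustrative):

    # noncontiguous=True returns a non-contiguous tensor directly, instead of
    # allocating a larger tensor and slicing it by hand.
    tensor1 = make_tensor((5, 1, 3), device, dtype, noncontiguous=True)
    tensor2 = make_tensor((5, 1, 3), device, dtype, noncontiguous=True)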

@dtypes(*torch.testing.get_all_dtypes())
@ops(foreach_binary_op_db)
def test_binary_op_tensors_on_different_devices(self, device, dtype, op):
if self.device_type != 'cuda':
Collaborator

Use @onlyCUDA decorator
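For example (onlyCUDA comes from the device-type test framework in torch.testing._internal.common_device_type):

    # @onlyCUDA restricts the test to CUDA devices, replacing the manual
    # self.device_type check inside the test body.
    @onlyCUDA
    @ops(foreach_binary_op_db)
    def test_binary_op_tensors_on_different_devices(self, device, dtype, op):
        ...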


@dtypes(*torch.testing.get_all_dtypes(include_bfloat16=True))
def test_binary_op_tensors_on_different_devices(self, device, dtype):
@dtypes(*torch.testing.get_all_dtypes())
Collaborator

you probably don't need to run this test for all the dtypes if the goal is just to error out for different devices

Collaborator Author

sounds good, and the unary version doesn't use this decorator either, thanks.

facebook-github-bot pushed a commit that referenced this pull request Jun 21, 2021
Summary:
This PR lets `torch.add` & `torch.sub` CUDA kernels cast `alpha` to `acc_type`, not `scalar_t`.
I do not remove the `cast`s from `test/test_foreach.py` because I'll do that in #59907 or a follow-up.

Current upstream `torch._foreach_add` & `torch._foreach_sub` upcast the `alpha` parameter to `acc_type<scalar_t>`, while `torch.add` & `torch.sub` do not. This is problematic because the outputs of `torch.add` and `torch.sub` then differ from those of `torch._foreach_add` and `torch._foreach_sub`, respectively, if the dtype of the input tensors is either `torch.half` or `torch.bfloat16`. The discrepancy is roughly proportional to `abs(alpha)`, except when `alpha` is exactly representable in 16 bits.
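A small hedged illustration of how the discrepancy can be observed (exact values depend on the device and rounding, so this is only a sketch):

    import torch

    # Before this change, torch.add rounds alpha to float16 before multiplying,
    # while torch._foreach_add accumulates in float32, so results can differ
    # for half-precision inputs when alpha is not exactly representable in 16 bits.
    a = torch.ones(4, dtype=torch.half, device='cuda')
    b = torch.ones(4, dtype=torch.half, device='cuda')
    alpha = 1e-4  # not exactly representable in float16

    eager = torch.add(a, b, alpha=alpha)
    foreach = torch._foreach_add([a], [b], alpha=alpha)[0]
    print((eager - foreach).abs().max())  # may be nonzero before this fix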

ref:
- `torch._foreach_add` & `torch._foreach_sub` cast `alpha`: https://github.com/pytorch/pytorch/blob/6d0fb85a623f5ef3f3f1a2afc3660cb71fa70511/aten/src/ATen/native/cuda/ForeachBinaryOpList.cu#L21-L28, `BinaryOpListAlphaFunctor` is defined here: https://github.com/pytorch/pytorch/blob/6d0fb85a623f5ef3f3f1a2afc3660cb71fa70511/aten/src/ATen/native/cuda/ForeachFunctors.cuh#L202

related: #58833, #59907

cc ngimel ptrblck mcarilli

Pull Request resolved: #60227

Reviewed By: mruberry

Differential Revision: D29252759

Pulled By: ngimel

fbshipit-source-id: 847f3b9493ae30a900f7445af00aef1abcc1ab21
facebook-github-bot pushed a commit that referenced this pull request Jun 24, 2021
Summary:
Follow up of #60227, related to #59907 & #58833

With this pull request, `torch.add` & `torch.sub` use `acc_type` for `Scalar` if either of two arguments is `Scalar`.
This mimics the behavior of [`torch.mul`](https://github.com/pytorch/pytorch/blob/master/aten/src/ATen/native/cuda/BinaryMulDivKernel.cu#L18), `torch._foreach_(add|sub).Scalar` and `torch._foreach_(add|sub).ScalarList`.

 ---

**reference**
- torch.mul CUDA kernel: https://github.com/pytorch/pytorch/blob/b0c9762e2d1dfcde549344628ad6be063378ef6a/aten/src/ATen/native/cuda/BinaryMulDivKernel.cu#L17-L25
- `torch._foreach_(add|sub).Scalar`: cast scalar https://github.com/pytorch/pytorch/blob/b0c9762e2d1dfcde549344628ad6be063378ef6a/aten/src/ATen/native/cuda/ForeachBinaryOpScalar.cu#L27
- `torch._foreach_(add|sub).ScalarList`: `BinaryOpScalarListFunctor` https://github.com/pytorch/pytorch/blob/b0c9762e2d1dfcde549344628ad6be063378ef6a/aten/src/ATen/native/cuda/ForeachFunctors.cuh#L180-L182, and `multi_tensor_apply` takes `scalar_t` and computes in `opmath_t` (almost equivalent to `accscalar_t`) https://github.com/pytorch/pytorch/blob/b0c9762e2d1dfcde549344628ad6be063378ef6a/aten/src/ATen/native/cuda/MultiTensorApply.cuh#L60-L68. `BinaryOpScalarListFunctor` is used at https://github.com/pytorch/pytorch/blob/b0c9762e2d1dfcde549344628ad6be063378ef6a/aten/src/ATen/native/cuda/ForeachBinaryOpScalarList.cu#L24

cc ngimel ptrblck mcarilli

Pull Request resolved: #60454

Reviewed By: VitalyFedyunin

Differential Revision: D29345035

Pulled By: ngimel

fbshipit-source-id: 5dbafbdfe029a9544ec2e58f17d547928e017a04
@crcrpar crcrpar force-pushed the fe/refactor-binary-test branch 2 times, most recently from 1f088f6 to 7f6f933 Compare June 28, 2021 04:59
@codecov

codecov bot commented Jun 28, 2021

Codecov Report

Merging #59907 (7f6f933) into master (6e9e30c) will increase coverage by 16.42%.
The diff coverage is 83.33%.

❗ Current head 7f6f933 differs from pull request most recent head 39cd7b8. Consider uploading reports for the commit 39cd7b8 to get more accurate results

@@             Coverage Diff             @@
##           master   #59907       +/-   ##
===========================================
+ Coverage   59.80%   76.22%   +16.42%     
===========================================
  Files         640     2064     +1424     
  Lines       83748   205482   +121734     
===========================================
+ Hits        50082   156632   +106550     
- Misses      33666    48850    +15184     

Collaborator

@ngimel ngimel left a comment

This looks good; only one minor comment and we can land.

Collaborator

should it be tensor2.t()?

Collaborator Author

oh, thank you very much

@crcrpar crcrpar force-pushed the fe/refactor-binary-test branch from 7f6f933 to 39cd7b8 Compare July 2, 2021 03:33
@facebook-github-bot
Contributor

@ngimel has imported this pull request. If you are a Facebook employee, you can view this diff on Phabricator.

@facebook-github-bot
Contributor

@ngimel merged this pull request in fac744e.

@crcrpar crcrpar deleted the fe/refactor-binary-test branch July 6, 2021 21:03

Labels

cla signed · Merged · open source · triaged
