Foreach Binary Test Refactor #59907

Conversation
💊 CI failures summary and remediations
As of commit 39cd7b8 (more details on the Dr. CI page and at hud.pytorch.org/pr/59907):
💚 Looks good so far! There are no failures yet. 💚
Preview docs built from this PR
This comment was automatically generated by Dr. CI.
    if same_size:
        return [make_tensor((N, N), device, dtype, noncontiguous=noncontiguous) for _ in range(N)]
This option is used by test_binary_op_tensors_on_different_devices
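For context, a minimal sketch of how such a sample-input helper might look. The same_size flag and the make_tensor call come from the diff above; the import path and the varying-size default branch are assumptions, not the PR's actual code.

```python
from torch.testing import make_tensor  # in older versions: torch.testing._internal.common_utils

def _foreach_inputs(N, device, dtype, *, noncontiguous=False, same_size=False):
    if same_size:
        # All tensors share the (N, N) shape; per the comment above, this is the
        # branch test_binary_op_tensors_on_different_devices relies on.
        return [make_tensor((N, N), device, dtype, noncontiguous=noncontiguous)
                for _ in range(N)]
    # Assumed default branch: vary the trailing dimension so the tests also
    # exercise lists of differently shaped tensors.
    return [make_tensor((N, N + i), device, dtype, noncontiguous=noncontiguous)
            for i in range(N)]
```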
First part of the review.
- Please unify the tests for scalar dtypes and for the fast and slow paths; otherwise it's hard to get a good picture of what's tested. If you want separate tests for this, the correct approach is decorators (but here it may be more trouble than it's worth).
- Please fix the upcasting issues for addition and subtraction in a separate PR; we should not have special handling in tests, and the operations have to produce the same results. It's OK that regular addition will produce slightly different results than before.
- Please streamline the input/input2 handling; I don't see why input2 is needed.
Thank you for the review. I'm removing input2.
Cool, thanks. Do you plan on fixing regular addition so that we don't have to have requires_cast?

Yes, and I wrote crcrpar@4627e30 (using the accumulate type and having requires_cast always return False) on top of this branch, and a local run of test_foreach.py was successful.

That's cool, so can you submit the PR for add separately, so that this one can become simpler without the need to handle requires_cast?
This looks much better, thanks! @mruberry and I looked over it and left a few comments, let us know what you think.
We could also merge the fix to regular add to avoid conversions.
test/test_foreach.py
Outdated
        self.func = func

    def __call__(self, inputs, **kwargs):
        if len(inputs) == 2 and isinstance(inputs[1], (int, float, complex, bool)):
you can test isinstance(inputs[1], Number)
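For reference, Number here is the abstract base class from Python's numbers module, so the check in the hunk above could be collapsed roughly like this (sketch of the same __call__ method):

```python
from numbers import Number

def __call__(self, inputs, **kwargs):
    # A single check against the Number ABC covers int, float, complex, and
    # bool (bool subclasses int, which registers with numbers.Integral).
    if len(inputs) == 2 and isinstance(inputs[1], Number):
        ...
```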
didn't know that, thanks
test/test_foreach.py
Outdated
    for N in N_values:
        tensors1 = self._get_test_data(device, dtype, N)
        tensors2 = self._get_test_data(device, dtype, N)

    def _regular_binary_test(self, dtype, op, ref, inputs, is_fastpath, *, alpha=None):
Now _regular_binary_test and _inplace_binary_test are sufficiently similar that you can unify them into a single function; it can have an arg that controls whether to clone the inputs.
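A minimal sketch of such a merged helper, assuming op is the foreach callable, ref computes the expected result out-of-place, and inputs is (tensor_list, second_operand); the real helper in the PR may take extra arguments such as is_fastpath.

```python
import re

def _binary_test(self, dtype, op, ref, inputs, is_fastpath, *, inplace=False, alpha=None):
    # For the in-place variant, run the reference on clones so it still sees
    # the original values; the foreach op then mutates inputs[0] directly.
    ref_inputs = [[t.clone() for t in inputs[0]], inputs[1]] if inplace else inputs
    kwargs = {} if alpha is None else {'alpha': alpha}
    try:
        actual = op(*inputs, **kwargs)
    except RuntimeError as e:
        # If the foreach op raises, the reference must raise the same error.
        with self.assertRaisesRegex(type(e), re.escape(str(e))):
            ref(*ref_inputs, **kwargs)
    else:
        expected = ref(*ref_inputs, **kwargs)
        # In-place foreach ops modify inputs[0]; out-of-place ops return new tensors.
        self.assertEqual(inputs[0] if inplace else actual, expected)
```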
        self._regular_binary_test(dtype, op, ref, inputs, is_fastpath, alpha=alpha)
        self._inplace_binary_test(dtype, inplace_op, inplace_ref, inputs, is_fastpath, alpha=alpha)

    @skipMeta
is it failing the meta test?
Should this be skipped for ROCm too?
Yes, it's failing because:
- this test compares the exception if one is raised
- foreach with meta tensors raises
- but the regular op with meta tensors does not raise
FAILED test/test_foreach.py::TestForeachMETA::test_binary_op_tensorlists_fastpath__foreach_add_meta_float32 - AssertionError: NotImplementedError not raised
FAILED test/test_foreach.py::TestForeachMETA::test_binary_op_tensorlists_fastpath__foreach_add_meta_float64 - AssertionError: NotImplementedError not raised
FAILED test/test_foreach.py::TestForeachMETA::test_binary_op_tensorlists_fastpath__foreach_div_meta_float32 - AssertionError: NotImplementedError not raised
FAILED test/test_foreach.py::TestForeachMETA::test_binary_op_tensorlists_fastpath__foreach_div_meta_float64 - AssertionError: NotImplementedError not raised
FAILED test/test_foreach.py::TestForeachMETA::test_binary_op_tensorlists_fastpath__foreach_mul_meta_float32 - AssertionError: NotImplementedError not raised
FAILED test/test_foreach.py::TestForeachMETA::test_binary_op_tensorlists_fastpath__foreach_mul_meta_float64 - AssertionError: NotImplementedError not raised
FAILED test/test_foreach.py::TestForeachMETA::test_binary_op_tensorlists_fastpath__foreach_sub_meta_float32 - AssertionError: NotImplementedError not raised
FAILED test/test_foreach.py::TestForeachMETA::test_binary_op_tensorlists_fastpath__foreach_sub_meta_float64 - AssertionError: NotImplementedError not raised
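Given those failures, the skip stays on the fastpath test; a sketch of how the decorator is applied (skipMeta and ops come from common_device_type, foreach_binary_op_db is the OpInfo list this PR introduces):

```python
from torch.testing._internal.common_device_type import ops, skipMeta

# The foreach fast path raises NotImplementedError for meta tensors while the
# regular reference ops do not, so the "raise the same exception" comparison
# cannot hold on the meta backend; skip the test there.
@skipMeta
@ops(foreach_binary_op_db)
def test_binary_op_tensorlists_fastpath(self, device, dtype, op):
    ...
```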
test/test_foreach.py
Outdated
        self._regular_binary_test(dtype, op, ref, inputs, is_fastpath)
        self._inplace_binary_test(dtype, inplace_op, inplace_ref, inputs, is_fastpath)
        if opinfo.supports_alpha_param:
            alpha = 3 if dtype in torch.testing.get_all_int_dtypes() else 3.14
can you create a complex alpha for complex inputs? There is dtype.is_complex to check.
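One way to pick an alpha per dtype family, roughly (the specific values are arbitrary):

```python
# Exercise alpha for every dtype family: integral dtypes get an integer alpha,
# complex dtypes a complex alpha, and floating dtypes a non-trivial float.
if dtype in torch.testing.get_all_int_dtypes():
    alpha = 3
elif dtype.is_complex:
    alpha = complex(3, 3)
else:
    alpha = 3.14
```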
test/test_foreach.py
Outdated
            disable_fastpath = True
        self._test_binary_op_tensorlists(device, dtype, op, N, True, disable_fastpath)

    @dtypes(*torch.testing.get_all_dtypes())
what are the dtypes that regular ops support but foreach doesn't? Otherwise you could test the fast path and slow path for the same dtypes.
ah, now that foreach binary ops use the AT_DISPATCH_ALL_TYPES_AND_COMPLEX* dispatcher, the dtypes might be the same. Double check.
The supported dtypes look the same because AT_DISPATCH_ALL_TYPES_AND_COMPLEX_AND3 is used and foreach doesn't support sparse tensors at this moment.
Yeah but that decorator doesn't make you test sparse inputs, so testing for just foreach supported dtypes (default) should be good.
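If someone wanted to double-check that claim, a quick throwaway comparison along these lines would do (hypothetical snippet; op.dtypes is the dtype set an OpInfo declares, and the import location of foreach_binary_op_db is an assumption):

```python
import torch
from torch.testing._internal.common_methods_invocations import foreach_binary_op_db

# Compare the dtypes each foreach OpInfo declares against the full dtype list
# the @dtypes decorator would have supplied.
all_dtypes = set(torch.testing.get_all_dtypes())
for op in foreach_binary_op_db:
    missing = all_dtypes - set(op.dtypes)
    print(op.name, missing)  # expected to be empty on this branch
```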
test/test_foreach.py
Outdated
        if dtype in [torch.bfloat16, torch.bool, torch.float16]:
            tensors2 = [torch.zeros(N, N, device=device, dtype=dtype).add(2) for _ in range(N)]
        # different devices
        if torch.cuda.is_available() and torch.cuda.device_count() > 1:
Please check that the device argument you got is cuda: torch.device(device).type == "cuda"
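i.e. something like:

```python
# Only exercise the multi-GPU branch when the test instance is actually
# running on CUDA and more than one device is available.
if torch.device(device).type == "cuda" and torch.cuda.device_count() > 1:
    ...
```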
test/test_foreach.py
Outdated
        res = torch._foreach_add([tensor1], [tensor2])
        torch._foreach_add_([tensor1], [tensor2])
        self.assertEqual(res, [tensor1])
        tensor1 = torch.randn(10, 10, device=device).to(dtype)
use make_tensor here
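i.e. roughly (make_tensor is the same test utility already used elsewhere in this diff):

```python
# make_tensor handles device and dtype directly, including integral, bool,
# and complex dtypes that torch.randn(...).to(dtype) would round-trip through float.
tensor1 = make_tensor((10, 10), device, dtype)
tensor2 = make_tensor((10, 10), device, dtype)
```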
test/test_foreach.py
Outdated
        # non contiguous
        tensor1 = torch.randn(5, 2, 1, 3, device=device)[:, 0]
        tensor2 = torch.randn(5, 2, 1, 3, device=device)[:, 0]
        tensor1 = torch.randn(5, 2, 1, 3, device=device).to(dtype)[:, 0]
make_tensor(noncontiguous=True)
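i.e. something like (the shape here is arbitrary):

```python
# Noncontiguous inputs without the manual slicing trick.
tensor1 = make_tensor((5, 3), device, dtype, noncontiguous=True)
tensor2 = make_tensor((5, 3), device, dtype, noncontiguous=True)
```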
test/test_foreach.py
Outdated
    @dtypes(*torch.testing.get_all_dtypes())
    @ops(foreach_binary_op_db)
    def test_binary_op_tensors_on_different_devices(self, device, dtype, op):
        if self.device_type != 'cuda':
Use @onlyCUDA decorator
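A sketch of the suggested change (decorator import path as in common_device_type):

```python
from torch.testing._internal.common_device_type import onlyCUDA, ops

# @onlyCUDA makes the device framework skip this test on non-CUDA devices,
# replacing the manual `if self.device_type != 'cuda'` early return.
@onlyCUDA
@ops(foreach_binary_op_db)
def test_binary_op_tensors_on_different_devices(self, device, dtype, op):
    ...
```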
test/test_foreach.py
Outdated
    @dtypes(*torch.testing.get_all_dtypes(include_bfloat16=True))
    def test_binary_op_tensors_on_different_devices(self, device, dtype):
    @dtypes(*torch.testing.get_all_dtypes())
you probably don't need to run this test for all the dtypes, if the goal is just to error out for the different devices
Sounds good, and the unary version doesn't use this decorator. Thanks.
Summary: This PR lets the torch.add & torch.sub CUDA kernels cast alpha to acc_type, not scalar_t. I do not remove the casts from test/test_foreach.py because I'll do that in #59907 or a follow-up to it.

Current upstream torch._foreach_add & torch._foreach_sub upcast the alpha parameter to acc_type<scalar_t>, while torch.add & torch.sub do not. This is problematic because the outputs of torch.add and torch.sub differ from torch._foreach_add and torch._foreach_sub, respectively, if the dtype of the input tensors is either torch.half or torch.bfloat16. The discrepancy is roughly proportional to abs(alpha), except when alpha is representable in 16 bits.

ref:
- torch._foreach_add & torch._foreach_sub cast alpha: https://github.com/pytorch/pytorch/blob/6d0fb85a623f5ef3f3f1a2afc3660cb71fa70511/aten/src/ATen/native/cuda/ForeachBinaryOpList.cu#L21-L28; BinaryOpListAlphaFunctor is defined here: https://github.com/pytorch/pytorch/blob/6d0fb85a623f5ef3f3f1a2afc3660cb71fa70511/aten/src/ATen/native/cuda/ForeachFunctors.cuh#L202

related: #58833, #59907

cc ngimel ptrblck mcarilli

Pull Request resolved: #60227
Reviewed By: mruberry
Differential Revision: D29252759
Pulled By: ngimel
fbshipit-source-id: 847f3b9493ae30a900f7445af00aef1abcc1ab21
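A quick, hypothetical way to observe the kind of discrepancy described above (CUDA-only; whether the printed difference is nonzero depends on the values chosen and on whether the kernel fix has landed):

```python
import torch

tensors = [torch.full((4,), 0.1, dtype=torch.half, device="cuda")]
alpha = 1000.1  # not exactly representable in 16 bits

# foreach upcasts alpha to the accumulate type; before this fix, plain
# torch.add performed the multiply-accumulate entirely in half precision.
out_foreach = torch._foreach_add(tensors, tensors, alpha=alpha)[0]
out_regular = torch.add(tensors[0], tensors[0], alpha=alpha)
print((out_foreach.float() - out_regular.float()).abs().max())
```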
Summary: Follow-up of #60227, related to #59907 & #58833.

With this pull request, torch.add & torch.sub use acc_type for the Scalar if either of the two arguments is a Scalar. This mimics the behavior of torch.mul (https://github.com/pytorch/pytorch/blob/master/aten/src/ATen/native/cuda/BinaryMulDivKernel.cu#L18), torch._foreach_(add|sub).Scalar, and torch._foreach_(add|sub).ScalarList.

reference:
- torch.mul CUDA kernel: https://github.com/pytorch/pytorch/blob/b0c9762e2d1dfcde549344628ad6be063378ef6a/aten/src/ATen/native/cuda/BinaryMulDivKernel.cu#L17-L25
- torch._foreach_(add|sub).Scalar casts the scalar: https://github.com/pytorch/pytorch/blob/b0c9762e2d1dfcde549344628ad6be063378ef6a/aten/src/ATen/native/cuda/ForeachBinaryOpScalar.cu#L27
- torch._foreach_(add|sub).ScalarList uses BinaryOpScalarListFunctor (https://github.com/pytorch/pytorch/blob/b0c9762e2d1dfcde549344628ad6be063378ef6a/aten/src/ATen/native/cuda/ForeachFunctors.cuh#L180-L182), and multi_tensor_apply handles scalar_t and computes in opmath_t (almost equivalent to accscalar_t): https://github.com/pytorch/pytorch/blob/b0c9762e2d1dfcde549344628ad6be063378ef6a/aten/src/ATen/native/cuda/MultiTensorApply.cuh#L60-L68. BinaryOpScalarListFunctor is used at https://github.com/pytorch/pytorch/blob/b0c9762e2d1dfcde549344628ad6be063378ef6a/aten/src/ATen/native/cuda/ForeachBinaryOpScalarList.cu#L24

cc ngimel ptrblck mcarilli

Pull Request resolved: #60454
Reviewed By: VitalyFedyunin
Differential Revision: D29345035
Pulled By: ngimel
fbshipit-source-id: 5dbafbdfe029a9544ec2e58f17d547928e017a04
Force-pushed from 1f088f6 to 7f6f933.
Codecov Report
@@ Coverage Diff @@
## master #59907 +/- ##
===========================================
+ Coverage 59.80% 76.22% +16.42%
===========================================
Files 640 2064 +1424
Lines 83748 205482 +121734
===========================================
+ Hits 50082 156632 +106550
- Misses 33666 48850 +15184
This looks good, only one minor comment and we can land
test/test_foreach.py
Outdated
should it be tensor2.t()?
oh, thank you very much
also, unify:
- int/float/complex/bool scalar tests -> scalar test
- int/float/complex/bool scalarlist tests -> scalarlist test
also, remove @dtypes decorator from slowpath tests
Force-pushed from 7f6f933 to 39cd7b8.
@ngimel has imported this pull request. If you are a Facebook employee, you can view this diff on Phabricator.
Related: #58833
Changes I'm a bit concerned about: TensorListScalarListMetadata<c10::complex<double>, 1>. This might be out of the scope of this pull request.

cc @ptrblck @ngimel @mcarilli