fix torch.linalg.norm and torch.norm for torch.complex32 datatype #133661
Conversation
🔗 Helpful Links: 🧪 See artifacts and rendered test results at hud.pytorch.org/pr/133661
Note: Links to docs will display an error until the docs builds have been completed.
✅ No Failures as of commit 9cb69b2 with merge base 028c5d3.
This comment was automatically generated by Dr. CI and updates every 15 minutes.
@mingfeima, could you please review this PR? Thanks!
We probably need to add these dtypes to our current testing framework.
LGTM, just remember to update test cases.
for ord in vector_ords:
    res = torch.linalg.norm(x, ord, keepdim=keepdim)
    res_float = torch.linalg.norm(x_cfloat, ord, keepdim=keepdim)
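For context, here is a hedged sketch of how such a chalf-vs-cfloat comparison might be set up; the names x, x_cfloat, and vector_ords come from the snippet above, but the shapes, ord values, and tolerances below are assumptions rather than the PR's actual test parameters.

```python
import torch

# Hypothetical setup mirroring the snippet above; concrete values are assumptions.
x_cfloat = torch.randn(10, dtype=torch.complex64)
x = x_cfloat.to(torch.complex32)  # torch.chalf tensor under test
vector_ords = [None, 0, 1, 2, float("inf"), float("-inf")]

for ord in vector_ords:
    for keepdim in (False, True):
        res = torch.linalg.norm(x, ord, keepdim=keepdim)
        res_float = torch.linalg.norm(x_cfloat, ord, keepdim=keepdim)
        # norm of a complex input returns a real tensor (float16 vs. float32 here)
        torch.testing.assert_close(res, res_float.to(res.dtype), rtol=1e-2, atol=1e-2)
```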
Isn't this op tested with OpInfos? Shall we add tests for the CPU there instead?
I got this warning when this op is tested with OpInfos; ComplexHalf doesn't seem to be fully supported yet:
/home/jiayisun/pytorch/torch/testing/_creation.py:233: UserWarning: ComplexHalf support is experimental and many operators don't support it yet. (Triggered internally at /home/jiayisun/pytorch/aten/src/ATen/EmptyTensor.cpp:46.)
result = torch.empty(shape, device=device, dtype=dtype)
Some dtypes for norm on device type cpu are only partially supported!
The following dtypes only worked on some samples during backward: {torch.complex32}.
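The warning is about autograd coverage rather than the forward kernel. A hedged way to see the distinction (the shape and the use of torch.linalg.norm here are illustrative assumptions):

```python
import torch

# The forward computation in torch.chalf works with this change; the warning
# above ("only worked on some samples during backward") is about autograd
# coverage, which may still be partial for complex32.
x = torch.randn(4, dtype=torch.complex64).to(torch.complex32).requires_grad_()
try:
    out = torch.linalg.norm(x)  # real-valued (float16) scalar result
    out.backward()
    print("backward ok, grad dtype:", x.grad.dtype)
except RuntimeError as err:
    print("backward not supported for this sample:", err)
```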
  sample_inputs_func=sample_inputs_norm,
- dtypes=floating_and_complex_types_and(torch.float16, torch.bfloat16),
+ dtypes=floating_and_complex_types_and(torch.float16, torch.bfloat16, torch.chalf),
  dtypesIfCUDA=floating_and_complex_types_and(torch.float16, torch.bfloat16),
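As a rough illustration of what this diff enables (the import path below is assumed from PyTorch's internal test utilities and is not part of this PR):

```python
import torch
from torch.testing._internal.common_dtype import floating_and_complex_types_and

# With torch.chalf added, the CPU dtype list for this OpInfo now includes
# complex32, while the CUDA list above is left unchanged by this PR.
cpu_dtypes = floating_and_complex_types_and(torch.float16, torch.bfloat16, torch.chalf)
print(torch.chalf in cpu_dtypes)  # expected: True
```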
Can we just add chalf support to CUDA too while we are at it? We already support complex types on GPU, so it should be an easy enough change. Just add kComplexHalf to the CUDA operator as well and delete this line. :)
AT_DISPATCH_FLOATING_AND_COMPLEX_TYPES(iter.input_dtype(), "norm_cuda", [&] {
No, issue #132634 also exists on CUDA; the GPU kernel has the same bug, and it runs here: https://github.com/pytorch/pytorch/blob/main/aten/src/ATen/native/cuda/ReduceNormKernel.cu#L35. So we cannot fix it just by modifying this line. I am not familiar with GPU kernels, so maybe you could open a new PR to fix it.
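For reference, a hedged probe one could run to check the CUDA-side behavior discussed here; this PR only changes the CPU kernel, so the GPU path may still show the issue:

```python
import torch

# Hypothetical check: does torch.linalg.norm handle a complex32 tensor on CUDA?
# Per the discussion above, the GPU kernel still carries the bug from #132634,
# so this may error out or give incorrect results until a follow-up PR.
if torch.cuda.is_available():
    x = torch.randn(8, dtype=torch.complex64, device="cuda").to(torch.complex32)
    try:
        print(torch.linalg.norm(x))
    except RuntimeError as err:
        print("chalf norm on CUDA not handled yet:", err)
```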
@Skylion007, can you please review this PR again? Thanks!
  }
- AT_DISPATCH_FLOATING_AND_COMPLEX_TYPES(iter.input_dtype(), "norm_cpu", [&] {
+ AT_DISPATCH_FLOATING_AND_COMPLEX_TYPES_AND3(kHalf, kBFloat16, kComplexHalf, iter.input_dtype(), "norm_cpu", [&] {
Are the kHalf and kBFloat16 types in the AND3 redundant? It seemed to support those types before; or does it suggest we aren't testing the right thing?
This is pretty complicated logic; let's update the GPU operator for chalf at the same time, for parity.
> Are the kHalf and kBFloat16 types in the AND3 redundant? It seemed to support those types before; or does it suggest we aren't testing the right thing?

Because I modified the logic above, kHalf and kBFloat16 inputs can now also reach this dispatch (the case where iter.input_dtype() == kHalf/kBFloat16 and iter.dtype(0) != kFloat).
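A hedged Python-level view of the cases being described; the float16 values and the use of the dtype argument are illustrative assumptions, not the kernel's exact dispatch logic:

```python
import torch

# Illustrative only: a float16 input can be reduced and returned in float16,
# or accumulated/returned in float32 via the dtype argument. The comment above
# says the float16-in/float16-out style of case is what now reaches the AND3
# dispatch on CPU.
x_half = torch.randn(16, dtype=torch.float16)
print(torch.linalg.norm(x_half).dtype)                       # torch.float16
print(torch.linalg.norm(x_half, dtype=torch.float32).dtype)  # torch.float32
```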
I checked that the UT covers this operator, but I got this warning when running it; ComplexHalf doesn't seem to be fully supported yet:
/home/jiayisun/pytorch/torch/testing/_creation.py:233: UserWarning: ComplexHalf support is experimental and many operators don't support it yet. (Triggered internally at /home/jiayisun/pytorch/aten/src/ATen/EmptyTensor.cpp:46.)
result = torch.empty(shape, device=device, dtype=dtype)
Some dtypes for norm on device type cpu are only partially supported!
The following dtypes only worked on some samples during backward: {torch.complex32}.
Hi @Skylion007, could you please review this PR again? Thanks!
@pytorchbot merge
Merge started. Your change will be merged once all checks pass (ETA 0-4 hours). Learn more about merging in the wiki. Questions? Feedback? Please reach out to the PyTorch DevX Team.
…torch#133661)
Fix pytorch#132634.
Pull Request resolved: pytorch#133661
Approved by: https://github.com/mingfeima, https://github.com/Skylion007
Stack from ghstack (oldest at bottom):
Fix #132634.
cc @jgong5 @mingfeima @XiaobingSuper @sanchitintel @ashokei @jingxu10
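For reference, a minimal hedged sketch of the kind of call this PR fixes, assuming the reproducer in #132634 resembles the tests discussed above:

```python
import torch

# A complex32 (chalf) vector; before this PR, reducing it with norm on CPU hit
# issue #132634, and after the fix it should agree with a complex64 reference.
x = torch.randn(5, dtype=torch.complex64).to(torch.complex32)
print(torch.linalg.norm(x))                       # float16 result
print(torch.norm(x))                              # same fix applies to torch.norm
print(torch.linalg.norm(x.to(torch.complex64)))   # complex64 reference (float32)
```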