OpInfo: norm #59259
Conversation
💊 CI failures summary: as of commit 80de388, ci.pytorch.org reported 1 failure (more details on the Dr. CI page).
```python
cases_negdim.append((shape, tuple(new_args), name.replace("_dim", "_neg_dim")))

def generator():
    if sample_types == 'default':
```
Why not separate these sample input sets into different sample input functions instead of using this if/elif statement to choose between them?
Since they all belong to the same operator, I thought it was OK to have them in one function, and the cases tuples are fairly separate. Let me know if we should split them into multiple functions.
I think it's more readable if they're separate. torch.norm is really a "clearing house" for multiple functions behind the scenes.
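One way to realize the suggested split, sketched here in plain Python. The case data and function bodies are illustrative stand-ins, not the PR's real code, which builds `SampleInput` objects via `make_tensor`:

```python
# Sketch: one sample-input function per norm variant, replacing a single
# generator that branches on a sample_types argument. The shapes, args,
# and names below are placeholder data for illustration.

def sample_inputs_norm(shapes=((2, 3), (4,))):
    # default variant: ordinary vector/matrix p-norms
    for shape in shapes:
        yield {"shape": shape, "args": (2,), "name": "2-norm"}

def sample_inputs_norm_nuc(shapes=((3, 3),)):
    # nuclear-norm variant: 2-D inputs only
    for shape in shapes:
        yield {"shape": shape, "args": ("nuc",), "name": "nuclear"}

def sample_inputs_norm_fro(shapes=((2, 2),)):
    # frobenius variant: the default when no p is given
    for shape in shapes:
        yield {"shape": shape, "args": ("fro",), "name": "frobenius"}
```

Each variant `OpInfo` entry would then point at its own function through `sample_inputs_func`, removing the `if/elif` chain entirely.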
```python
elif sample_types == 'nuc':
    for shape, args, name in cases_nuc:  # type: ignore[assignment]
        yield SampleInput(make_arg(shape), args=args, name=name)
elif sample_types == 'jit':
```
a "jit" category is a little odd
This is mostly ad hoc for the cases where JIT failed. 😛 (Though renaming "jit" to "fro" does make sense.)
```python
    )
),
OpInfo('norm',
       variant_test_name='jit',
```
As mentioned above this variant seems a little weird.
Ideally variants would correspond to different code paths with different properties. So if nuclear norm and frobenius norm are actually different functions with different properties then they can have different OpInfos.
Maybe this can become the frobenius variant?
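Under that proposal, the database entries might look like the following sketch. `OpInfoSketch` is a minimal stand-in for torch's real `OpInfo` class, with fields abbreviated for illustration:

```python
from dataclasses import dataclass

@dataclass
class OpInfoSketch:
    # Minimal stand-in for torch's OpInfo: just enough fields to show
    # how per-variant entries would replace the single 'jit' variant.
    name: str
    variant_test_name: str = ""

op_db_sketch = [
    OpInfoSketch('norm'),                           # default p-norms
    OpInfoSketch('norm', variant_test_name='nuc'),  # nuclear norm
    OpInfoSketch('norm', variant_test_name='fro'),  # frobenius, replacing 'jit'
]
```

Each entry would also carry its own `sample_inputs_func`, so the variants map one-to-one onto the distinct code paths.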
(all of its inputs use the frobenius norm, since the default for norm is frobenius)
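For reference, the Frobenius norm of a matrix is just the entrywise 2-norm of its elements, which is why the default inputs all exercise the same path. A quick check in plain Python:

```python
import math

def frobenius_norm(matrix):
    # Square root of the sum of squared entries, i.e. the 2-norm of
    # the flattened matrix.
    return math.sqrt(sum(x * x for row in matrix for x in row))

m = [[3.0, 0.0], [0.0, 4.0]]
print(frobenius_norm(m))  # 5.0
```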
They have different code-paths.
pytorch/aten/src/ATen/native/cpu/ReduceOpsKernel.cpp, lines 199 to 277 at commit 44c20ce:
```cpp
static void norm_kernel_tensor_iterator_impl(
    TensorIterator& iter,
    const Scalar& p) {
  // NOLINTNEXTLINE(cppcoreguidelines-init-variables)
  float val;
  if (p.isIntegral(false)) {
    val = p.to<int64_t>();
  } else if (p.isFloatingPoint()) {
    // NOLINTNEXTLINE(cppcoreguidelines-narrowing-conversions,bugprone-narrowing-conversions)
    val = p.to<double>();
  } else {
    AT_ERROR("norm_kernel_tensor_iterator_impl expects norm to be integer or float");
  }
  // In the dispatch code blocks below, reduction kernels accumulate results as
  // the type `acc_t`. When `scalar_t` is complex, `acc_t` is the downgraded
  // real number type. Otherwise, `acc_t` and `scalar_t` are the same type.
  if (val == 0) {
    AT_DISPATCH_FLOATING_AND_COMPLEX_TYPES_AND2(kHalf, kBFloat16, iter.input_dtype(), "norm_cpu", [&] {
      using acc_t = typename scalar_value_type<scalar_t>::type;
      binary_kernel_reduce(
          iter,
          NormZeroOps<scalar_t, acc_t>(),
          acc_t(0));
    });
  } else if (val == 1) {
    AT_DISPATCH_FLOATING_AND_COMPLEX_TYPES_AND2(kHalf, kBFloat16, iter.input_dtype(), "norm_cpu", [&] {
      using acc_t = typename scalar_value_type<scalar_t>::type;
      binary_kernel_reduce(
          iter,
          NormOneOps<scalar_t, acc_t>(),
          acc_t(0));
    });
  } else if (val == 2) {
    AT_DISPATCH_FLOATING_AND_COMPLEX_TYPES_AND2(kHalf, kBFloat16, iter.input_dtype(), "norm_cpu", [&] {
      using acc_t = typename scalar_value_type<scalar_t>::type;
      binary_kernel_reduce(
          iter,
          NormTwoOps<scalar_t, acc_t>(),
          acc_t(0));
    });
  } else if (val == INFINITY) {
    AT_DISPATCH_FLOATING_AND_COMPLEX_TYPES_AND2(kHalf, kBFloat16, iter.input_dtype(), "norm_cpu", [&] {
      using acc_t = typename scalar_value_type<scalar_t>::type;
      binary_kernel_reduce(
          iter,
          AbsMaxOps<scalar_t, acc_t>(),
          acc_t(0));
    });
  } else if (val == -INFINITY) {
    AT_DISPATCH_FLOATING_AND_COMPLEX_TYPES_AND2(kHalf, kBFloat16, iter.input_dtype(), "norm_cpu", [&] {
      using acc_t = typename scalar_value_type<scalar_t>::type;
      binary_kernel_reduce(
          iter,
          AbsMinOps<scalar_t, acc_t>(),
          std::numeric_limits<acc_t>::max());
    });
  } else {
    AT_DISPATCH_FLOATING_AND_COMPLEX_TYPES_AND2(kHalf, kBFloat16, iter.input_dtype(), "norm_cpu", [&] {
      using acc_t = typename scalar_value_type<scalar_t>::type;
      binary_kernel_reduce(
          iter,
          NormOps<scalar_t, acc_t> { acc_t(val) },
          acc_t(0));
    });
  }
  // For complex outputs, the above kernels do not touch the imaginary values,
  // so we must zero them out
  if (isComplexType(iter.output().scalar_type())) {
    at::imag(iter.output()).zero_();
  }
}
```
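The branch structure of that kernel can be mirrored in a few lines of plain Python. This is a reference sketch for real-valued 1-D input only, not the actual ATen implementation, which reduces over TensorIterator with per-p functor types:

```python
import math

def norm_reference(values, p=2.0):
    # Mirrors the specialization in norm_kernel_tensor_iterator_impl:
    # dedicated branches for p = 0, 1, 2, +inf, -inf, then a general
    # fallback for any other p.
    if p == 0:
        # 0-"norm": count of nonzero entries
        return float(sum(1 for v in values if v != 0))
    if p == 1:
        return sum(abs(v) for v in values)
    if p == 2:
        return math.sqrt(sum(v * v for v in values))
    if p == math.inf:
        return max(abs(v) for v in values)
    if p == -math.inf:
        return min(abs(v) for v in values)
    return sum(abs(v) ** p for v in values) ** (1.0 / p)

v = [3.0, -4.0]
print(norm_reference(v, 2))         # 5.0
print(norm_reference(v, math.inf))  # 4.0
```

The specialized functors (`NormZeroOps`, `NormOneOps`, `NormTwoOps`, `AbsMaxOps`, `AbsMinOps`) exist because the common p values admit cheaper accumulation than the general `abs(x)**p` path.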
```python
def sample_inputs_norm(op_info, device, dtype, requires_grad, sample_types='default', **kwargs):
    make_arg = partial(make_tensor, device=device, dtype=dtype, requires_grad=requires_grad)

    cases_nuc = (
```
This might be more readable if there was just one sample input func per norm variant.
Another option would be to try and create a more generic generation for norm inputs. That seems tricky, however, and ultimately we plan to be rid of torch.norm, so extending its test coverage isn't especially interesting.
Sure! Will split the sample func.
This is a challenging and significant function to OpInfo and I think this PR is most of the way there, @kshitij12345. I made a few inline comments for readability, and I'm curious to hear your thoughts. Basically I suggest renaming "jit" to "fro" and cutting up the sample inputs for the different OpInfos to make the code a little more straightforward.
@mruberry I have addressed the questions above.

Thanks!
Nice work, @kshitij12345!
@mruberry has imported this pull request. If you are a Facebook employee, you can view this diff on Phabricator.
Summary:
Reference: pytorch#54261

EDIT: ~~Test takes a whopping 4 mins to run 😓~~ (Filtered tests also included linalg norm.) Newly added tests take around 2 mins.

```
==================================================== 193 passed, 224 skipped, 27224 deselected, 5 warnings in 138.87s (0:02:18) ====================================================
```

Pull Request resolved: pytorch#59259
Reviewed By: jbschlosser
Differential Revision: D28833962
Pulled By: mruberry
fbshipit-source-id: 40b24d6a8cb8b7d231b2f6b34b87cee4f136c5f9