[cpp wrapper] add AOTI shim for collective ops #154492
Conversation
🔗 Helpful Links
🧪 See artifacts and rendered test results at hud.pytorch.org/pr/154492
Note: Links to docs will display an error until the docs builds have been completed.
✅ You can merge normally! (1 Unrelated Failure)
As of commit c44381a with merge base 04178d3.
UNSTABLE - The following job is marked as unstable, possibly due to flakiness on trunk:
This comment was automatically generated by Dr. CI and updates every 15 minutes.
Attention! One of PyTorch's C-stable API files was changed. You MUST NOT change existing function declarations in this file, as this header defines a stable C ABI. If you need to change the signature for a function, introduce a new v2 version of the function and modify code generation to target the new version of the function.
Force-pushed from b34bd4c to 6e9d8ed.
    from torch.testing._internal.inductor_utils import HAS_CPU

    def load_test_module(name):
Has this function been defined elsewhere? Can we reuse it?
Thanks. The change has been removed per Baobin's suggestion.
    return (
-       type(node) == ir._CollectiveKernel and (op is None or node.op_overload is op)
+       isinstance(node, ir._CollectiveKernel)
+       and not isinstance(node, ir._WaitKernel)
By definition, ir._WaitKernel is a subclass of ir._CollectiveKernel, so why is it excluded here?
Because ir._WaitKernel does not have any constant_args like other collective kernels.
The function is_collective is used here: https://github.com/pytorch/pytorch/blob/main/test/inductor/test_snode_runtime.py#L244. If the check is True, it eventually calls _get_group_size_by_name(node.constant_args[-1]) (https://github.com/pytorch/pytorch/blob/main/torch/_inductor/comm_analysis.py#L73), which raises an error for ir._WaitKernel.
Ditto for the change to get_collective_group_size in comm_analysis.py.
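For context, here is a minimal sketch of the check being discussed, based on the diff above (the real helper lives in torch/_inductor and may carry additional conditions):

```python
from torch._inductor import ir

def is_collective(node, op=None):
    # _WaitKernel subclasses _CollectiveKernel but carries no constant_args
    # (no group_name), so downstream group-size lookups would fail on it.
    return (
        isinstance(node, ir._CollectiveKernel)
        and not isinstance(node, ir._WaitKernel)
        and (op is None or node.op_overload is op)
    )
```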
Then shouldn't get_collective_group_size handle _WaitKernel differently? (disclaimer: I am not sure if _WaitKernel has a concept of group size)
_WaitKernel is a special _CollectiveKernel that does not have a group_size (derived from group_name) like the others do.
For example, compare the schemas of wait_tensor and all_reduce:
"wait_tensor(Tensor tensor) -> Tensor"
"all_reduce(Tensor input, str reduce_op, str group_name) -> Tensor"
@desertfire @yifuwang Could you have a look at this PR? Thanks~
        AtenTensorHandle qScaleAndZeros,
        AtenTensorHandle* ret0);

    AOTI_TORCH_EXPORT AOTITorchError aoti_torch_cpu__c10d_functional_all_reduce_(
I wish there were a way to autogen these...
Thanks! As collective ops are not defined under torch.ops.aten, their AOTI shims cannot be automatically generated. Do you have any suggestions for how to autogen these kernels?
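For illustration, a small check (again a sketch assuming the ops are registered under torch.ops._c10d_functional) showing why the aten-based shim codegen skips them:

```python
import torch

op = torch.ops._c10d_functional.all_reduce_.default
# The namespace is "_c10d_functional", not "aten", so the existing shim
# autogeneration that walks torch.ops.aten never sees this op.
print(op._schema.name.split("::")[0])  # _c10d_functional
```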
    testdir = Path(__file__).absolute().parent.parent
    with mock.patch("sys.path", [*sys.path, str(testdir)]):
        return SourceFileLoader(
            name, str(testdir / f"{name.replace('.', '/')}.py")
This looks fragile. Can you simply test cpp_wrapper in test_c10d_functional_native.py?
Thanks and modified!
    kernel = self.op_overload
    self.cpp_kernel_name = kernel._schema.name
    if cpp_kernel_name is not None:
        self.cpp_kernel_name = cpp_kernel_name
Why do you need to explicitly pass in cpp_kernel_name here? cpp_wrapper_cpu.py should have taken care of the schema name to C shim name conversion.
Thanks for your comment.
In set_cpp_kernel_name, the optional param cpp_kernel_name is not used. I suppose we need to take it into account when it is explicitly passed.
As collective ops are not defined under torch.ops.aten, their AOTI shims cannot be automatically generated, so here I need to pass cpp_kernel_name, like what mkldnn_ir.py does: https://github.com/pytorch/pytorch/blob/main/torch/_inductor/mkldnn_ir.py#L1265.
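A minimal sketch of what honoring the explicit argument could look like (hypothetical body; the actual set_cpp_kernel_name in torch/_inductor/ir.py has more logic around overload names):

```python
def set_cpp_kernel_name(self, cpp_kernel_name=None):
    if cpp_kernel_name is not None:
        # Explicit C shim name, needed for ops outside torch.ops.aten such as
        # the _c10d_functional collectives (same approach as mkldnn_ir.py).
        self.cpp_kernel_name = cpp_kernel_name
    else:
        # Fall back to deriving the name from the op schema, as before.
        kernel = self.op_overload
        self.cpp_kernel_name = kernel._schema.name
```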
        AtenTensorHandle* ret0) {
      AOTI_TORCH_CONVERT_EXCEPTION_TO_ERROR_CODE({
        auto tmp_result = c10d::all_reduce_(
            *tensor_handle_to_tensor_pointer(inp), reduce_op, group_name);
Please use resolve_tensor_dispatch_flags instead of tensor_handle_to_tensor_pointer.
That modification raises a compile error ("cannot bind non-const lvalue reference of type 'at::Tensor&' to an rvalue of type 'at::Tensor'"), because all_reduce_ expects an at::Tensor&.
Then you just need to create another tmp variable first.
Not a particular issue with this PR; I will do some cleanup later.
Force-pushed from 9b1ad35 to f72f4fe.
Force-pushed from d96ff80 to e467d25.
@desertfire Hi, could you help review again? Thanks~
@pytorchbot merge
Merge started. Your change will be merged once all checks pass (ETA 0-4 hours). Learn more about merging in the wiki. Questions? Feedback? Please reach out to the PyTorch DevX Team.
Merge failed. Reason: 1 job has failed, first few of them are: trunk / macos-py3-arm64 / build. Details for Dev Infra team: raised by workflow job.
Force-pushed from e467d25 to c44381a.
@pytorchbot merge
Merge started. Your change will be merged once all checks pass (ETA 0-4 hours). Learn more about merging in the wiki. Questions? Feedback? Please reach out to the PyTorch DevX Team.
Implementations:
1. Move collective ops to c10d namespace, so that we can call them externally.
2. Add AOTI shims for collective ops.

Testing:
1. Add c10d functional UT for CPU.
2. Include the above one in cpp wrapper UT.

Pull Request resolved: pytorch#154492
Approved by: https://github.com/desertfire
cc @H-Huang @awgu @wanchaol @fegin @fduwjj @wz337 @wconstab @d4l3k @voznesenskym @penguinwu @EikanWang @jgong5 @Guobing-Chen @XiaobingSuper @zhuhaozhe @blzheng @wenzhe-nrv @jiayisunx @ipiszy @chenyang78 @kadeng @muchulee8 @amjames @chauhang @aakhundov