[cpp wrapper] add AOTI shim for collective ops by Valentine233 · Pull Request #154492 · pytorch/pytorch · GitHub

Conversation

@Valentine233 (Collaborator) commented May 28, 2025

Implementations:

  1. Move collective ops to the c10d namespace, so that we can call them externally.
  2. Add AOTI shims for collective ops (a sketch of the pattern is shown below).

Testing:

  1. Add a c10d functional UT for CPU.
  2. Include the above test in the cpp wrapper UT.
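
For context, each shim is a plain-C entry point that forwards to the corresponding c10d:: function and converts C++ exceptions into error codes so they never cross the C ABI boundary. A minimal sketch of the pattern, modeled on the all_reduce_ shim discussed later in this thread (the new_tensor_handle handoff is assumed from the usual AOTI shim utilities):

    AOTI_TORCH_EXPORT AOTITorchError aoti_torch_cpu__c10d_functional_all_reduce_(
        AtenTensorHandle inp,
        const char* reduce_op,
        const char* group_name,
        AtenTensorHandle* ret0) {
      AOTI_TORCH_CONVERT_EXCEPTION_TO_ERROR_CODE({
        // Forward to the c10d implementation, now callable externally.
        auto tmp_result = c10d::all_reduce_(
            *tensor_handle_to_tensor_pointer(inp), reduce_op, group_name);
        // Hand the result back across the C ABI as a fresh tensor handle.
        *ret0 = new_tensor_handle(std::move(tmp_result));
      });
    }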

cc @H-Huang @awgu @wanchaol @fegin @fduwjj @wz337 @wconstab @d4l3k @voznesenskym @penguinwu @EikanWang @jgong5 @Guobing-Chen @XiaobingSuper @zhuhaozhe @blzheng @wenzhe-nrv @jiayisunx @ipiszy @chenyang78 @kadeng @muchulee8 @amjames @chauhang @aakhundov

@pytorch-bot bot commented May 28, 2025

🔗 Helpful Links

🧪 See artifacts and rendered test results at hud.pytorch.org/pr/154492

Note: Links to docs will display an error until the docs builds have been completed.

✅ You can merge normally! (1 Unrelated Failure)

As of commit c44381a with merge base 04178d3:

UNSTABLE - The following job is marked as unstable, possibly due to flakiness on trunk:

This comment was automatically generated by Dr. CI and updates every 15 minutes.

@Valentine233 added the topic: not user facing label May 28, 2025
@Valentine233 marked this pull request as draft May 28, 2025 08:09
@github-actions (Contributor):

Attention! One of PyTorch's C-stable API files was changed

You MUST NOT change existing function declarations in this file, as this header defines a stable C ABI. If you need to change a function's signature, introduce a new v2 version of the function and modify code generation to target the new version.
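
For illustration, the versioning pattern the bot describes might look like this (aoti_torch_foo and its parameters are hypothetical, purely for the example):

    // Existing declaration: frozen, since this header defines a stable C ABI.
    AOTI_TORCH_EXPORT AOTITorchError aoti_torch_foo(
        AtenTensorHandle inp,
        AtenTensorHandle* ret0);

    // A signature change goes into a new v2 entry point instead, and code
    // generation is updated to target the v2 symbol.
    AOTI_TORCH_EXPORT AOTITorchError aoti_torch_foo_v2(
        AtenTensorHandle inp,
        int64_t new_flag,  // hypothetical new parameter
        AtenTensorHandle* ret0);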


Caused by:

@Valentine233 marked this pull request as ready for review June 4, 2025 01:13
from torch.testing._internal.inductor_utils import HAS_CPU


def load_test_module(name):
Collaborator:

Has this function been defined elsewhere? Can we reuse it?

Collaborator Author:

Thanks. The change has been removed per Baobin's suggestion.

return (
    isinstance(node, ir._CollectiveKernel)
    and not isinstance(node, ir._WaitKernel)
    and (op is None or node.op_overload is op)
)
Collaborator:

Given that ir._WaitKernel is a subclass of ir._CollectiveKernel, why is it excluded here?

Collaborator Author (@Valentine233, Jun 4, 2025):

Because ir._WaitKernel does not have any constant_args, unlike the other collective kernels.

The function is_collective is used here: https://github.com/pytorch/pytorch/blob/main/test/inductor/test_snode_runtime.py#L244. If the check is True, it eventually calls _get_group_size_by_name(node.constant_args[-1]) (https://github.com/pytorch/pytorch/blob/main/torch/_inductor/comm_analysis.py#L73), which raises an error for ir._WaitKernel. The same applies to the change to get_collective_group_size in comm_analysis.py.

Contributor:

Then shouldn't get_collective_group_size handle _WaitKernel differently? (Disclaimer: I am not sure whether _WaitKernel has a concept of group size.)

Collaborator Author (@Valentine233, Jun 16, 2025):

_WaitKernel is a special _CollectiveKernel that does not have a group_size (generated from its group_name) like the others.

For example, compare the schemas of wait_tensor and all_reduce:
"wait_tensor(Tensor tensor) -> Tensor"
"all_reduce(Tensor input, str reduce_op, str group_name) -> Tensor"

@jcaip added the triaged label Jun 5, 2025
@Valentine233 (Collaborator Author):

@desertfire @yifuwang Could you have a look at this PR? Thanks!

    AtenTensorHandle qScaleAndZeros,
    AtenTensorHandle* ret0);

AOTI_TORCH_EXPORT AOTITorchError aoti_torch_cpu__c10d_functional_all_reduce_(
    AtenTensorHandle inp,
    const char* reduce_op,
    const char* group_name,
    AtenTensorHandle* ret0);
Contributor:

I wish there were a way to autogen these...

Collaborator Author (@Valentine233, Jun 16, 2025):

Thanks! As collective ops are not defined under torch.ops.aten, their AOTI shims cannot be generated automatically. Do you have any suggestions for how to autogen these kernels?

    testdir = Path(__file__).absolute().parent.parent
    with mock.patch("sys.path", [*sys.path, str(testdir)]):
        return SourceFileLoader(
            name, str(testdir / f"{name.replace('.', '/')}.py")
        ).load_module()
Contributor:

This looks fragile. Can you simply test cpp_wrapper in test_c10d_functional_native.py?

Collaborator Author:

Thanks, modified!

kernel = self.op_overload
self.cpp_kernel_name = kernel._schema.name
if cpp_kernel_name is not None:
    self.cpp_kernel_name = cpp_kernel_name
Contributor:

Why do you need to explicitly pass in cpp_kernel_name here? cpp_wrapper_cpu.py should have taken care of the schema name to C shim name conversion.

Collaborator Author (@Valentine233, Jun 16, 2025):

Thanks for your comment.

In the function set_cpp_kernel_name, the optional param cpp_kernel_name was not used. I suppose we need to take it into account when it is explicitly passed.

As collective ops are not defined under torch.ops.aten, their AOTI shims cannot be automatically generated, so I need to pass cpp_kernel_name explicitly, like what is done here: https://github.com/pytorch/pytorch/blob/main/torch/_inductor/mkldnn_ir.py#L1265.


    AtenTensorHandle* ret0) {
  AOTI_TORCH_CONVERT_EXCEPTION_TO_ERROR_CODE({
    auto tmp_result = c10d::all_reduce_(
        *tensor_handle_to_tensor_pointer(inp), reduce_op, group_name);
Contributor:

Please use resolve_tensor_dispatch_flags instead of tensor_handle_to_tensor_pointer.

Collaborator Author:

The modification would raise an error, "cannot bind non-const lvalue reference of type 'at::Tensor&' to an rvalue of type 'at::Tensor'", because all_reduce_ expects an at::Tensor&.

Contributor:

Then you just need to create another tmp variable first.

Not an issue particular to this PR; I will do some cleanup later.
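
A sketch of that workaround, assuming resolve_tensor_dispatch_flags returns an at::Tensor by value (which is what the rvalue error indicates; the new_tensor_handle handoff is carried over from the sketch near the top of this thread):

    AOTI_TORCH_CONVERT_EXCEPTION_TO_ERROR_CODE({
      // Bind the by-value result to a named lvalue first, so it can be
      // passed to c10d::all_reduce_, which takes an at::Tensor&.
      at::Tensor inp_tensor = resolve_tensor_dispatch_flags(inp);
      auto tmp_result = c10d::all_reduce_(inp_tensor, reduce_op, group_name);
      *ret0 = new_tensor_handle(std::move(tmp_result));
    });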

@Valentine233 force-pushed the collective_c_shim branch 2 times, most recently from d96ff80 to e467d25, June 19, 2025 01:30
@Valentine233 (Collaborator Author):

@desertfire Hi, could you help review again? Thanks~

@Valentine233 (Collaborator Author):

@pytorchbot merge

@pytorch-bot bot added the ciflow/trunk label Jun 24, 2025
@pytorchmergebot (Collaborator):

Merge started

Your change will be merged once all checks pass (ETA 0-4 Hours).

Learn more about merging in the wiki.

Questions? Feedback? Please reach out to the PyTorch DevX Team

Advanced Debugging: check the merge workflow status here

@pytorchmergebot (Collaborator):

Merge failed

Reason: 1 job has failed: trunk / macos-py3-arm64 / build

Details for Dev Infra team. Raised by workflow job.

@Valentine233 (Collaborator Author):

@pytorchbot merge

@pytorchmergebot (Collaborator):

Merge started

Your change will be merged once all checks pass (ETA 0-4 Hours).

Learn more about merging in the wiki.

Questions? Feedback? Please reach out to the PyTorch DevX Team

Advanced Debugging: check the merge workflow status here

skarjala pushed a commit to skarjala/pytorch that referenced this pull request Jun 25, 2025
Implementations:
1. Move collective ops to c10d namespace, so that we can call them externally.
2. Add AOTI shims for collective ops.

Testing
1. Add c10d functional UT for cpu.
2. Include the above one in cpp wrapper UT.

Pull Request resolved: pytorch#154492
Approved by: https://github.com/desertfire
@github-actions bot deleted the collective_c_shim branch July 25, 2025 02:20

Labels

ciflow/inductor, ciflow/trunk, Merged, module: inductor, oncall: distributed, open source, release notes: distributed (c10d), release notes: inductor (aoti), topic: not user facing, triaged
