[PT] support custom all_gather and reduce_scatter comms by xunnanxu · Pull Request #155189 · pytorch/pytorch · GitHub

Conversation

@xunnanxu
Contributor

@xunnanxu xunnanxu commented Jun 5, 2025

Summary:
This change introduces two comm override APIs, `set_custom_all_gather` and `set_custom_reduce_scatter`, to allow custom behavior for each collective.

These let users control how the comm buffers are allocated and the exact comm implementation, for flexibility.
For details, see the docstring of `Comm` in `_fsdp_api.py`.
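
Below is a minimal sketch of how these hooks might be wired up. The exact protocol lives in the `Comm`/`BaseComm` docstrings; the method names, signatures, and the `MyAllGatherComm` class here are illustrative assumptions, not the actual API surface.

```python
import torch
import torch.distributed as dist
from torch.distributed.fsdp import fully_shard

class MyAllGatherComm:
    # Hypothetical comm object; the real interface is defined by
    # `Comm`/`BaseComm` in `_fsdp_api.py`, and these method names and
    # signatures are assumptions for illustration only.

    def allocate(self, size, *, dtype, device):
        # Control where the all-gather buffer comes from, e.g. a custom
        # memory pool instead of the default caching allocator.
        return torch.empty(size, dtype=dtype, device=device)

    def all_gather(self, output, inp, group):
        # Swap in any collective implementation, e.g. a fused or
        # topology-aware all-gather; here we just call the stock one.
        dist.all_gather_into_tensor(output, inp, group=group)

model = fully_shard(model)
# Assumed call site: the setter added by this PR, applied on the
# fully_shard-wrapped root module; `set_custom_reduce_scatter` would
# take an analogous reduce-scatter comm object.
model.set_custom_all_gather(MyAllGatherComm())
```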

Related PR:
#150564

Test Plan: CI

Differential Revision: D75714362

cc @H-Huang @awgu @wanchaol @fegin @fduwjj @wz337 @wconstab @d4l3k @voznesenskym @penguinwu @EikanWang @jgong5 @Guobing-Chen @XiaobingSuper @zhuhaozhe @blzheng @wenzhe-nrv @jiayisunx @ipiszy @chenyang78 @kadeng @muchulee8 @amjames @chauhang @aakhundov

@pytorch-bot

pytorch-bot bot commented Jun 5, 2025

🔗 Helpful Links

🧪 See artifacts and rendered test results at hud.pytorch.org/pr/155189

Note: Links to docs will display an error until the docs builds have been completed.

✅ No Failures

As of commit 1cacb20 with merge base f79689b (image):
💚 Looks good so far! There are no failures yet. 💚

This comment was automatically generated by Dr. CI and updates every 15 minutes.

@pytorch-bot pytorch-bot bot added ciflow/inductor, module: inductor, oncall: distributed, and release notes: distributed (fsdp) labels Jun 5, 2025
@facebook-github-bot
Contributor

This pull request was exported from Phabricator. Differential Revision: D75714362

@xunnanxu xunnanxu force-pushed the export-D75714362 branch 2 times, most recently from 2b428e6 to 569a4d6 Compare June 5, 2025 06:27

@xunnanxu xunnanxu force-pushed the export-D75714362 branch 2 times, most recently from e806f3a to b03004b Compare June 5, 2025 07:19

@xunnanxu xunnanxu marked this pull request as draft June 5, 2025 16:40
@xunnanxu xunnanxu force-pushed the export-D75714362 branch from 37d962b to 995e205 Compare June 6, 2025 05:22

Summary:
Pull Request resolved: pytorch#155189

This change introduces two comm override APIs, `set_custom_all_gather` and `set_custom_reduce_scatter`, to allow custom behavior for each collective.

These let users control how the comm buffers are allocated and the exact comm implementation, for flexibility.
For details, see the docstrings of `Comm` and `BaseComm` in `_fsdp_api.py`.

Test Plan: CI

Differential Revision: D75714362

@xunnanxu xunnanxu marked this pull request as ready for review June 30, 2025 08:13
@xunnanxu xunnanxu requested review from kwen2501, lw and weifengpy June 30, 2025 08:14
Contributor

@weifengpy weifengpy left a comment


appreciate your persistence in pushing this to the very end

@pytorch-bot pytorch-bot bot added the ciflow/trunk label Jul 1, 2025
@xunnanxu
Contributor Author

xunnanxu commented Jul 2, 2025

@pytorchbot merge

@pytorchmergebot
Collaborator

Merge failed

Reason: This PR has internal changes and must be landed via Phabricator! Please try re-importing/re-exporting the PR!


@facebook-github-bot
Contributor

@pytorchbot merge

(Initiating merge automatically since Phabricator Diff has merged)

@pytorchmergebot
Collaborator

Merge started

Your change will be merged once all checks pass (ETA 0-4 Hours).

Learn more about merging in the wiki.

Questions? Feedback? Please reach out to the PyTorch DevX Team

Advanced Debugging
Check the merge workflow status here.

xunnanxu added a commit to xunnanxu/pytorch that referenced this pull request Jul 2, 2025
…ther with custom comm hooks (pytorch#157487)

Summary:

This is a follow up after the PR to add comm override support: pytorch#155189

The previous PR loosely checks the allocation mixin classes, which isn't really safe, as the actual hook may still override the behavior.
This may lead to unnecessary confusion for no good use case. So for now we just make the 2 sets of APIs largely incompatible (see the sketch below):
1. Setting custom comms after `set_allocate_memory_from_process_group_for_comm()` is OK.
2. Setting `set_allocate_memory_from_process_group_for_comm()` after custom comms is not OK.

Basically, `set_allocate_memory_from_process_group_for_comm()` is a drop-in hammer, while `set_custom_all_gather()`/`set_custom_reduce_scatter()` are finer-grained scalpels that require more crafted code.

We can revisit this if a use case in between comes up, but for now the two can be viewed as largely independent of each other (even though they share some of the underlying pieces for now; that could be subject to change and should not be exposed to end users).
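
A sketch of the ordering rule, reusing the hypothetical `MyAllGatherComm` from the earlier example plus an analogous, equally hypothetical `MyReduceScatterComm`; the exact failure mode is an assumption:

```python
# OK: the coarse-grained switch first, then the finer-grained hooks.
model.set_allocate_memory_from_process_group_for_comm()
model.set_custom_all_gather(MyAllGatherComm())

# Not OK: once custom comms are set, the coarse-grained switch is
# rejected (assumed to raise; the exact exception type may differ).
model.set_custom_reduce_scatter(MyReduceScatterComm())
model.set_allocate_memory_from_process_group_for_comm()  # error
```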

Test Plan: added UT

Reviewed By: weifengpy

Differential Revision: D77681620
pytorchmergebot pushed a commit that referenced this pull request Jul 3, 2025
…ther with custom comm hooks (#157487)


Pull Request resolved: #157487
Approved by: https://github.com/weifengpy
