[PT] support custom all_gather and reduce_scatter comms #155189
Conversation
🔗 Helpful Links: 🧪 See artifacts and rendered test results at hud.pytorch.org/pr/155189

✅ No failures as of commit 1cacb20 with merge base f79689b. (This comment was automatically generated by Dr. CI and updates every 15 minutes.)
This pull request was exported from Phabricator. Differential Revision: D75714362
Summary: Pull Request resolved: pytorch#155189

This change introduces two comm override APIs, `set_custom_all_gather` and `set_custom_reduce_scatter`, to allow custom behavior for each collective. This lets users control how the comm buffers are allocated and the exact comm implementation, for flexibility.

For details, see the docstrings of `Comm` and `BaseComm` in `_fsdp_api.py`.

Test Plan: CI

Differential Revision: D75714362
Appreciate your persistence in pushing this to the very end.
@pytorchbot merge

Merge failed. Reason: This PR has internal changes and must be landed via Phabricator! Please try reimporting/re-exporting the PR.

@pytorchbot merge (Initiating merge automatically since Phabricator Diff has merged)

Merge started. Your change will be merged once all checks pass (ETA 0-4 hours).
…ther with custom comm hooks (#157487)

Summary: This is a follow-up to the PR that added comm override support: #155189

The previous PR only loosely checked the allocation mixin classes, which isn't really safe, since the actual hook may still override the behavior. This may lead to unnecessary confusion for no good use case. So for now we make the two sets of APIs largely incompatible:

1. Setting custom comms after `set_allocate_memory_from_process_group_for_comm()` is OK.
2. Setting `set_allocate_memory_from_process_group_for_comm()` after custom comms is not allowed.

Basically, `set_allocate_memory_from_process_group_for_comm()` is a drop-in hammer, while `set_custom_all_gather()`/`set_custom_reduce_scatter()` are finer-grained scalpels that require more crafted code. We can revisit this if a use case in between comes up, but for now the two can be viewed as largely independent of each other (even though they share some of the underlying pieces for now, that is subject to change and should not be exposed to end users).

Test Plan: added UT

Differential Revision: D77681620

Pull Request resolved: #157487

Approved by: https://github.com/weifengpy
Summary:
This change introduces two comm override APIs, `set_custom_all_gather` and `set_custom_reduce_scatter`, to allow custom behavior for each collective. This lets users control how the comm buffers are allocated and the exact comm implementation, for flexibility.
For details, see the docstring of `Comm` in `_fsdp_api.py`. Related PR:
#150564
Test Plan: CI
Differential Revision: D75714362
cc @H-Huang @awgu @wanchaol @fegin @fduwjj @wz337 @wconstab @d4l3k @voznesenskym @penguinwu @EikanWang @jgong5 @Guobing-Chen @XiaobingSuper @zhuhaozhe @blzheng @wenzhe-nrv @jiayisunx @ipiszy @chenyang78 @kadeng @muchulee8 @amjames @chauhang @aakhundov
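To make the "control buffer allocation and the exact comm implementation" idea concrete, here is a minimal single-process sketch of the shape such a custom all-gather object might take. The `allocate`/`all_gather` method names are illustrative assumptions modeled on the `Comm`/`BaseComm` docstrings the PR points to; the collective is simulated with plain Python lists rather than a real process group.

```python
# Hypothetical sketch of a custom all-gather comm object.
# A single-process simulation: "ranks" are just entries in a list,
# not a real torch.distributed collective.

class CustomAllGatherComm:
    def __init__(self, world_size):
        self.world_size = world_size

    def allocate(self, shard_numel):
        # The hook owns output-buffer allocation; here we simply build
        # a flat buffer sized to hold every rank's shard.
        return [0] * (shard_numel * self.world_size)

    def all_gather(self, out, shards):
        # Simulated collective: copy each rank's shard into its slot
        # of the preallocated output buffer.
        for rank, shard in enumerate(shards):
            n = len(shard)
            out[rank * n:(rank + 1) * n] = shard
        return out

comm = CustomAllGatherComm(world_size=2)
out = comm.allocate(shard_numel=3)
result = comm.all_gather(out, shards=[[1, 2, 3], [4, 5, 6]])
```

Because the hook both allocates `out` and performs the gather, an implementation could back the buffer with a custom allocator or route the collective through a different transport, which is the flexibility the PR description claims.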