-
Notifications
You must be signed in to change notification settings - Fork 25.7k
Make _Join, _Joinable, _JoinHook public #62605
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Conversation
🔗 Helpful links
💊 CI failures summary and remediationsAs of commit e0cc3ac (more details on the Dr. CI page): 💚 💚 Looks good so far! There are no failures yet. 💚 💚 This comment was automatically generated by Dr. CI (expand for details).Follow this link to opt-out of these comments for your Pull Requests.Please report bugs/suggestions to the (internal) Dr. CI Users group. |
|
Should |
|
@andwgu has imported this pull request. If you are a Facebook employee, you can view this diff on Phabricator. |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
LGTM. Added some minor comments. Thanks!
| ``throw_on_early_termination`` is enabled, both of which using an all- | ||
| reduce. | ||
| Arguments: |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
This renders a bit differently than other modules. E.g., below is DDP's parameters:
And this is Join's parameters:
Here is how DDP's args docstring, though I am not sure if changing Arguments to Args is sufficient. But this is a minor thing, we can fix that in followup PRs.
pytorch/torch/nn/parallel/distributed.py
Lines 385 to 388 in c4196be
| Args: | |
| module (Module): module to be parallelized | |
| device_ids (list of int or torch.device): CUDA devices. |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
I have not looked into it too deeply, but I think sphinx may have updated recently (#61601). When I look at DistributedDataParallel 's render from my local build, it is similar, and changing Arguments to Args does not make a difference.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Oh I see
torch/distributed/algorithms/join.py
Outdated
|
|
||
| @abstractmethod | ||
| def _join_hook(self, **kwargs) -> _JoinHook: | ||
| def _join_hook(self, **kwargs) -> JoinHook: |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Do we need to make join_hook, join_device, and join_process_group public?
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
#62605 (comment)
I was wondering that. I will make them public.
| """ | ||
| ... | ||
|
|
||
|
|
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
I couldn't comment on line 81, so adding comments here. Any reason _join_process_group 's return type is Any? Is it because ProcessGroup is not a public type?
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Yes, that is the reason. I think for now, we have to type all process groups as Any.
torch/distributed/algorithms/join.py
Outdated
| To implement a join hook for the generic join context manager, define a | ||
| class that inherits from :class:`_JoinHook`, override ``main_hook()`` and | ||
| class that inherits from :class:`JoinHook`, override ``main_hook()`` and | ||
| ``post_hook()`` as appropriate, and override ``device()`` and |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
device() and process_group() methods are not available in JoinHook. Do you mean join_device() and join_process_group() in Joinable?
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Good catch. This is leftover from when device and process_group were part of JoinHook. I will fix this.
|
|
||
| Generic Join Context Manager | ||
| ============================ | ||
|
|
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Shall we add a short paragraph describing the purpose of this join context manager?
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
TODO: when this lands, and when the tutorial lands, let's also add a link to this doc page to pointing to the tutorial page.
|
@andwgu has imported this pull request. If you are a Facebook employee, you can view this diff on Phabricator. |
Summary: Addresses: #62605 (comment) Pull Request resolved: #62785 Test Plan: I checked the render, and the link redirects as desired. Reviewed By: mrshenli Differential Revision: D30133229 Pulled By: andwgu fbshipit-source-id: baefe0d1f1b78ece44bb42e67629bc130dbf8e9a



Overview:
This removes the preceding
_from_Join,_Joinable, and_JoinHookin preparation for adding the generic join context manager tutorial (see here). This also adds a docs page, which can be linked from the tutorial. Here is a render of the docs page.Test Plan:
DistributedDataParallel.join():ZeroRedundancyOptimizer:NOTE: DDP overlap tests are failing due to a landing race. See #62592. Once the fix is landed, I will rebase, and tests should be passing.
Join: