-
Notifications
You must be signed in to change notification settings - Fork 25.7k
Refactor non-joined process computation #61555
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Conversation
💊 CI failures summary and remediationsAs of commit 6763301 (more details on the Dr. CI page and at hud.pytorch.org/pr/61555): 💚 💚 Looks good so far! There are no failures yet. 💚 💚 Preview docs built from this PR This comment was automatically generated by Dr. CI (expand for details).Follow this link to opt-out of these comments for your Pull Requests.Please report bugs/suggestions to the (internal) Dr. CI Users group. |
3dba67b
to
9b880c4
Compare
I need to revisit the pickling fix in the latest commit. Right now, it implicitly assumes that the model will not be saved/pickled inside a join context, which is arguably reasonable but should be made explicit (e.g. by adding a warning). I will see if I can figure out a way to properly save the I also think I can refactor the |
This looks great to me. And we will definitely need docstrings + tutorials + examples to show how to customize Joinable |
NVM, I don't see a reason why different join hooks wants to behave differently. |
): | ||
|
||
super(DistributedDataParallel, self).__init__() | ||
_Joinable.__init__(self) |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Do we need this? Looks like super().__init__()
should have already covered all parent classes?
class First(object):
def __init__(self):
super(First, self).__init__()
print("first")
class Second(object):
def __init__(self):
super(Second, self).__init__()
print("second")
class Third(First, Second):
def __init__(self):
super(Third, self).__init__()
print("third")
x = Third()
outputs
second
first
third
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Module
is not designed for cooperative inheritance since there was probably never a use case before. In particular, Module.__init__()
does not call super(Module, self).__init__()
, so super(DistributedDataParallel, self).__init__()
will call Module.__init__()
but not continue to call _Joinable.__init__()
.
Option 1: Leave it as is.
Option 2: Add super(Module, self).__init__()
to the end of Module.__init__()
:
pytorch/torch/nn/modules/module.py
Lines 250 to 266 in 8e6d899
def __init__(self): | |
""" | |
Initializes internal Module state, shared by both nn.Module and ScriptModule. | |
""" | |
torch._C._log_api_usage_once("python.nn_module") | |
self.training = True | |
self._parameters = OrderedDict() | |
self._buffers = OrderedDict() | |
self._non_persistent_buffers_set = set() | |
self._backward_hooks = OrderedDict() | |
self._is_full_backward_hook = None | |
self._forward_hooks = OrderedDict() | |
self._forward_pre_hooks = OrderedDict() | |
self._state_dict_hooks = OrderedDict() | |
self._load_state_dict_pre_hooks = OrderedDict() | |
self._modules = OrderedDict() |
Option 3: Switch the order of inheritance:
class DistributedDataParallel(_Joinable, Module):
Then, as long as we include super(_Joinable, self).__init__()
in _Joinable.__init__()
, then DistributedDataParallel.__init__()
will call Module.__init__()
.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Option 1 SGTM. Let's also create an issue to track this. I am not sure why nn.Module
does not call super().__init__()
.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Similar issue with torch.optim.Optimizer
. I have filed an issue to track this: #61662
joinables: List[_Joinable], | ||
enable: bool = True, | ||
throw_on_early_termination: bool = False, | ||
**kwargs, |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
would I be correct that is_first_joinable
should not be passed in kwargs?
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Correct. is_first_joinable
is managed internally by the join context manager. The user never needs to interact with it.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
LGTM. Thanks for coming up with the solution to consolidate join detection allreduce ops.
torch/distributed/algorithms/join.py
Outdated
This method should be called from a :class:`_Joinable` object before | ||
its per-iteration collective communications. For example, this should | ||
be called at the beginning of the forward pass in DDP. |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
DDP -> :class: DistributedDataParallel
method. | ||
Returns: | ||
An async work handle for the all-reduce meant to notify the context |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
We are retiring PG async work handle. It might be better to return the future object from Work.get_future()
.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
I did not make the change in this PR since Reducer::set_forward_pass_work_handle()
expects the async work handle and not the future.
I filed an issue to track this: #61661.
@andwgu has imported this pull request. If you are a Facebook employee, you can view this diff on Phabricator. |
Overview:
This refactors the computation on non-joined processes relating to the join context manager. The concept was inspired by a comment from @pritamdamania.
Changes:
This introduces a
_Joinable
abstract base class, which requires a_join_hook()
method and_join_device()
and_join_process_group()
property methods. Any class that we want to be compatible with the generic join context manager should inherit from_Joinable
and implement_join_hook()
,_join_device()
, and_join_process_group()
. (Thedevice
andprocess_group
information has been moved from_JoinHook
to_Joinable
.)The generic join context manager now takes in a
List[_Joinable]
instead ofList[_JoinHook]
. The motivation for this is that previously, by passing the_JoinHook
s into the context manager, the class providing a_JoinHook
can modify the context manager's behavior, but the context manager cannot modify the class's behavior. This is solved by giving the context manager a reference to the class's instance.This implementation reserves the field
_join_config
in every_Joinable
to store a_JoinConfig
instance, which holds all dynamic fields needed from the_Joinable
for the join context manager:enable
,throw_on_early_termination
, andis_first_joinable
. ("dynamic" here means that for a given_Joinable
instance, the values for those fields may change across different join context usages.) In particular, these fields are needed to implement a methodnotify_join_context()
, which encapsulates the computation performed on non-joined processes relating to the join context manager --- (1) the all-reduce to indicate that the process has not yet joined and (2) the all-reduce to check whether to throw an exception ifthrow_on_uneven_inputs=True
. The idea is that every_Joinable
class only needs to make a call tonotify_join_context()
before its per-iteration collective communications; it is a simple one-line addition.Only the first
_Joinable
instance passed into the context manager actually performs the collective communications innotify_join_context()
. In that case, the method returns an async work handle for the initial all-reduce indicating that the process not yet joined. Otherwise, the method returnsNone
. This conditional logic is handled internally without additional input from the user.New API:
Now, the example usage would look like:
Any arguments meant for a join hook (e.g.
divide_by_initial_world_size
) must be specified as keyword arguments. For example:They will be forwarded to every
_join_hook()
function via**kwargs
. This creates a clear separation between the variables needed by the context manager (enable
andthrow_on_early_termination
) and those needed by the_Joinable
class (e.g.divide_by_initial_world_size
).Recap:
After this change, the relevant information to use the generic join context manager looks like the following (omitting prefix
_
from names):C
(e.g.DistributedDataParallel
) that we want to be able to use theJoin
context.C
inherit fromJoinable
and implementjoin_hook() -> JoinHook
,join_device()
, andjoin_process_group()
.join_hook()
, we define aCJoinHook
class inheriting fromJoinHook
and implementmain_hook()
andpost_hook()
as needed.C
's per-iteration collective communications and add a call toJoin.notify_join_context()
.Joinable.__init__(self)
inC
's constructor.C.join_config
field will be used internally by the context manager. This does not affectC
's serializability.C
's join hook can be passed in as keyword arguments to the context manager:with Join([C()], arg1=..., arg2=...):
.Test Plan:
I ran the existing DDP join tests:
I ran the ZeRO join tests: