Refactor non-joined process computation by awgu · Pull Request #61555 · pytorch/pytorch · GitHub

Conversation

@awgu (Collaborator) commented on Jul 12, 2021

Overview:
This refactors the computation on non-joined processes relating to the join context manager. The concept was inspired by a comment from @pritamdamania.

Changes:
This introduces a _Joinable abstract base class, which requires a _join_hook() method as well as _join_device() and _join_process_group() property methods. Any class that we want to be compatible with the generic join context manager should inherit from _Joinable and implement _join_hook(), _join_device(), and _join_process_group(). (The device and process_group information has been moved from _JoinHook to _Joinable.)
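
As a rough sketch of what this interface implies (class names and the import path below are assumptions, and the exact signatures may differ from the PR):

import torch
from torch.distributed.algorithms.join import _Join, _Joinable, _JoinHook  # import path assumed

class _MyJoinHook(_JoinHook):
    def __init__(self, joinable):
        self.joinable = joinable

    def main_hook(self):
        # Shadow the per-iteration collective communication that already-joined
        # processes must keep participating in (e.g. a dummy all-reduce).
        ...

    def post_hook(self, is_last_joiner: bool):
        # One-time cleanup/synchronization after all processes have joined.
        ...

class MyJoinable(_Joinable):
    def __init__(self, process_group, device):
        super().__init__()
        self._process_group = process_group
        self._device = device

    def _join_hook(self, **kwargs) -> _JoinHook:
        return _MyJoinHook(self)

    @property
    def _join_device(self) -> torch.device:
        return self._device

    @property
    def _join_process_group(self):
        return self._process_group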

The generic join context manager now takes in a List[_Joinable] instead of a List[_JoinHook]. The motivation is that previously, by passing the _JoinHooks into the context manager, the class providing a _JoinHook could modify the context manager's behavior, but the context manager could not modify the class's behavior. Giving the context manager a reference to the class's instance solves this.

This implementation reserves the field _join_config in every _Joinable to store a _JoinConfig instance, which holds all dynamic fields needed from the _Joinable for the join context manager: enable, throw_on_early_termination, and is_first_joinable. ("Dynamic" here means that for a given _Joinable instance, the values of those fields may change across different join context usages.) In particular, these fields are needed to implement a method notify_join_context(), which encapsulates the computation performed on non-joined processes relating to the join context manager: (1) the all-reduce indicating that the process has not yet joined and (2) the all-reduce checking whether to throw an exception when throw_on_early_termination=True. The idea is that every _Joinable class only needs to call notify_join_context() before its per-iteration collective communications; it is a simple one-line addition.

Only the first _Joinable instance passed into the context manager actually performs the collective communications in notify_join_context(). In that case, the method returns an async work handle for the initial all-reduce indicating that the process has not yet joined. Otherwise, the method returns None. This conditional logic is handled internally without additional input from the user.
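
For example, a _Joinable's per-iteration method might begin as follows (a sketch; the exact call signature of notify_join_context() is assumed here):

def forward(self, *inputs):
    # One-line notification to the join context manager. Only the first
    # _Joinable passed to the context manager actually issues the all-reduce;
    # for the others this is assumed to return None.
    work = _Join.notify_join_context(self)
    # ... per-iteration collective communications follow ...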

New API:
Now, the example usage would look like:

ddp_model = DistributedDataParallel(...)
zero_optim = ZeroRedundancyOptimizer(ddp_model.parameters(), ...)
with _Join([ddp_model, zero_optim]):
    ...

Any arguments meant for a join hook (e.g. divide_by_initial_world_size) must be specified as keyword arguments. For example:

with _Join([ddp_model, zero_optim], divide_by_initial_world_size=False):
    ...

They will be forwarded to every _join_hook() function via **kwargs. This creates a clear separation between the variables needed by the context manager (enable and throw_on_early_termination) and those needed by the _Joinable class (e.g. divide_by_initial_world_size).
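
For instance, a hook provider could pick the arguments it cares about out of **kwargs (a sketch; _DDPJoinHook is a placeholder hook class and the divide_by_initial_world_size handling below is illustrative rather than DDP's exact code):

def _join_hook(self, **kwargs) -> _JoinHook:
    # Pull out the hook-specific argument forwarded by the context manager.
    divide_by_initial_world_size = kwargs.get("divide_by_initial_world_size", True)
    return _DDPJoinHook(self, divide_by_initial_world_size=divide_by_initial_world_size)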

Recap:
After this change, the relevant information for using the generic join context manager looks like the following (omitting the _ prefix from names; a usage sketch follows the list):

  • Suppose we have a class C (e.g. DistributedDataParallel) that we want to be able to use the Join context.
  • We make C inherit from Joinable and implement join_hook() -> JoinHook, join_device(), and join_process_group().
  • To implement join_hook(), we define a CJoinHook class inheriting from JoinHook and implement main_hook() and post_hook() as needed.
  • We locate a place before C's per-iteration collective communications and add a call to Join.notify_join_context().
  • We call Joinable.__init__(self) in C's constructor.
  • The C.join_config field will be used internally by the context manager. This does not affect C's serializability.
  • Run time arguments for C's join hook can be passed in as keyword arguments to the context manager: with Join([C()], arg1=..., arg2=...):.
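
Putting the recap together, a hypothetical end-to-end usage with uneven per-rank inputs might look like this (rank_local_loader and loss_fn are placeholders; setup details elided):

ddp_model = DistributedDataParallel(...)
zero_optim = ZeroRedundancyOptimizer(ddp_model.parameters(), ...)

with _Join([ddp_model, zero_optim]):
    # Each rank may iterate a different number of times; once a rank joins,
    # its join hooks shadow the collective communications until all ranks finish.
    for inputs, targets in rank_local_loader:
        loss = loss_fn(ddp_model(inputs), targets)
        loss.backward()
        zero_optim.step()
        zero_optim.zero_grad()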

Test Plan:
I ran the existing DDP join tests:

touch /tmp/barrier && TEMP_DIR="/tmp" BACKEND="nccl" WORLD_SIZE="2" gpurun python test/distributed/test_distributed_fork.py -- TestDistBackendWithFork.test_ddp_uneven_inputs TestDistBackendWithFork.test_ddp_uneven_inputs_stop_iteration_sync_bn TestDistBackendWithFork.test_ddp_grad_div_uneven_inputs TestDistBackendWithFork.test_ddp_uneven_input_join_disable TestDistBackendWithFork.test_ddp_uneven_input_exception

I ran the ZeRO join tests:

gpurun4 python test/distributed/optim/test_zero_redundancy_optimizer.py TestZeroRedundancyOptimizerDistributed.test_zero_join_gpu TestZeroRedundancyOptimizerDistributed.test_zero_join_cpu

@facebook-github-bot added the oncall: distributed and cla signed labels on Jul 12, 2021
@facebook-github-bot (Contributor) commented on Jul 12, 2021

💊 CI failures summary and remediations

As of commit 6763301 (more details on the Dr. CI page and at hud.pytorch.org/pr/61555):


💚 💚 Looks good so far! There are no failures yet. 💚 💚


Preview docs built from this PR


@awgu force-pushed the generic_join branch 3 times, most recently from 3dba67b to 9b880c4, on July 13, 2021 00:27
@awgu requested a review from mrshenli on July 13, 2021 01:05
@awgu (Collaborator, Author) commented on Jul 13, 2021

I need to revisit the pickling fix in the latest commit. Right now, it implicitly assumes that the model will not be saved/pickled inside a join context, which is arguably reasonable but should be made explicit (e.g. by adding a warning). I will see if I can figure out a way to properly save the _join_config field and load it (the issue is that the process group is not pickle-able).


I also think I can refactor the process_group and device information so that it is not included in both the _JoinHook and the _JoinConfig. I will do that tomorrow.
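
On the pickling point, one hypothetical workaround (illustrative only, not necessarily what this PR ends up doing) would be to drop the non-picklable state when serializing:

def __getstate__(self):
    state = self.__dict__.copy()
    # _join_config holds a process group, which is not pickle-able, so drop it;
    # it would need to be re-established on load or on the next join context entry.
    state.pop("_join_config", None)
    return state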

@mrshenli (Contributor) commented, quoting the PR description:

This implementation reserves the field _join_config in every _Joinable to store a _JoinConfig instance, which holds all dynamic fields needed from the _Joinable for the join context manager: enable, throw_on_early_termination, and is_first_joinable. ("Dynamic" here means that for a given _Joinable instance, the values of those fields may change across different join context usages.) In particular, these fields are needed to implement a method notify_join_context(), which encapsulates the computation performed on non-joined processes relating to the join context manager: (1) the all-reduce indicating that the process has not yet joined and (2) the all-reduce checking whether to throw an exception when throw_on_early_termination=True. The idea is that every _Joinable class only needs to call notify_join_context() before its per-iteration collective communications; it is a simple one-line addition.

This looks great to me. And we will definitely need docstrings + tutorials + examples to show how to customize Joinable.

@mrshenli (Contributor) commented on Jul 13, 2021, quoting the PR description:

They will be forwarded to every _join_hook() function via **kwargs. This creates a clear separation between the variables needed by the context manager (enable and throw_on_early_termination) and those needed by the _Joinable class (e.g. divide_by_initial_world_size).

What if different Joinables need different values for those args?

NVM, I don't see a reason why different join hooks would want to behave differently.

):
    super(DistributedDataParallel, self).__init__()
    _Joinable.__init__(self)
Contributor commented:

Do we need this? Looks like super().__init__() should have already covered all parent classes?

class First(object):
    def __init__(self):
        super(First, self).__init__()
        print("first")

class Second(object):
    def __init__(self):
        super(Second, self).__init__()
        print("second")

class Third(First, Second):
    def __init__(self):
        super(Third, self).__init__()
        print("third")

x = Third()

outputs

second
first
third

@awgu (Collaborator, Author) commented on Jul 13, 2021:

Module is not designed for cooperative inheritance since there was probably never a use case before. In particular, Module.__init__() does not call super(Module, self).__init__(), so super(DistributedDataParallel, self).__init__() will call Module.__init__() but not continue to call _Joinable.__init__().

Option 1: Leave it as is.

Option 2: Add super(Module, self).__init__() to the end of Module.__init__():

def __init__(self):
    """
    Initializes internal Module state, shared by both nn.Module and ScriptModule.
    """
    torch._C._log_api_usage_once("python.nn_module")
    self.training = True
    self._parameters = OrderedDict()
    self._buffers = OrderedDict()
    self._non_persistent_buffers_set = set()
    self._backward_hooks = OrderedDict()
    self._is_full_backward_hook = None
    self._forward_hooks = OrderedDict()
    self._forward_pre_hooks = OrderedDict()
    self._state_dict_hooks = OrderedDict()
    self._load_state_dict_pre_hooks = OrderedDict()
    self._modules = OrderedDict()
    super(Module, self).__init__()  # the proposed addition

Option 3: Switch the order of inheritance:

class DistributedDataParallel(_Joinable, Module):

Then, as long as we include super(_Joinable, self).__init__() in _Joinable.__init__(), DistributedDataParallel.__init__() will also end up calling Module.__init__().
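
For reference, a toy version of Option 3's constructor chaining (standalone classes mirroring the First/Second/Third example above, not the real ones):

class Joinable:
    def __init__(self):
        super().__init__()   # continues along the MRO (here, to Module)
        print("joinable")

class Module:
    def __init__(self):
        # Mirrors nn.Module, which does not call super().__init__().
        print("module")

class DistributedDataParallel(Joinable, Module):
    def __init__(self):
        super().__init__()   # MRO: Joinable -> Module
        print("ddp")

x = DistributedDataParallel()

outputs

module
joinable
ddp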

Contributor commented:

Option 1 SGTM. Let's also create an issue to track this. I am not sure why nn.Module does not call super().__init__().

Collaborator (Author) commented:

Similar issue with torch.optim.Optimizer. I have filed an issue to track this: #61662

joinables: List[_Joinable],
enable: bool = True,
throw_on_early_termination: bool = False,
**kwargs,
Contributor commented:

Would I be correct that is_first_joinable should not be passed in kwargs?

Collaborator (Author) commented:

Correct. is_first_joinable is managed internally by the join context manager. The user never needs to interact with it.

@awgu marked this pull request as ready for review on July 13, 2021 23:27
@mrshenli (Contributor) left a review:

LGTM. Thanks for coming up with the solution to consolidate join detection allreduce ops.

This method should be called from a :class:`_Joinable` object before
its per-iteration collective communications. For example, this should
be called at the beginning of the forward pass in DDP.
Contributor commented:

DDP -> :class:`DistributedDataParallel`

method.
Returns:
An async work handle for the all-reduce meant to notify the context
Contributor commented:

We are retiring PG async work handle. It might be better to return the future object from Work.get_future().
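
For reference, the difference would roughly be the following (a sketch using standard c10d calls; ones and group are placeholders, and this is not what the PR does):

# Current: return the async work handle directly.
work = dist.all_reduce(ones, group=group, async_op=True)

# Suggested direction: return the future obtained from the work handle instead.
fut = dist.all_reduce(ones, group=group, async_op=True).get_future()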

Collaborator (Author) commented:

I did not make the change in this PR since Reducer::set_forward_pass_work_handle() expects the async work handle and not the future.

I filed an issue to track this: #61661.

@facebook-github-bot (Contributor) commented:

@andwgu has imported this pull request. If you are a Facebook employee, you can view this diff on Phabricator.

@facebook-github-bot (Contributor) commented:

@andwgu merged this pull request in 57feb35.
