Add overlap with DDP to ZeRO (two approaches) #62157
Conversation
💊 CI failures summary and remediations (Dr. CI): as of commit ab54c52, 💚 Looks good so far! There are no failures yet. 💚
Would I be correct to assume this has no impact on non-overlapping ZeRO performance or correctness? If so, I don't have concerns about landing this. We can land and improve in followup PRs.
This map also lives in DistributedOptimizer. Shall we consolidate these two into a single source of truth?
pytorch/torch/distributed/optim/optimizer.py, lines 196 to 209 in 73f1e2d:

```python
# dict to map a user passed in optimizer_class to a functional
# optimizer class if we have already defined inside the
# distributed.optim package, this is so that we hide the
# functional optimizer to user and still provide the same API.
functional_optim_map = {
    optim.Adagrad: _FunctionalAdagrad,
    optim.Adam: _FunctionalAdam,
    optim.AdamW: _FunctionalAdamW,
    optim.SGD: _FunctionalSGD,
    optim.Adadelta: _FunctionalAdadelta,
    optim.RMSprop: _FunctionalRMSprop,
    optim.Rprop: _FunctionalRprop,
    optim.Adamax: _FunctionalAdamax,
}
```
This looks fine for now. But let's add a followup PR to automatically do the translation for users, similar to how DistributedOptimizer handles this.
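For illustration, a sketch of what that automatic translation could look like, assuming the functional_optim_map from the snippet above is in scope (the helper name here is hypothetical):

```python
import torch.optim as optim

def _as_functional_optim(optimizer_class):
    # Hypothetical helper: look up the functional counterpart of a
    # user-facing optimizer class, erroring out if none is registered.
    try:
        return functional_optim_map[optimizer_class]
    except KeyError:
        raise ValueError(
            f"{optimizer_class} does not have a functional counterpart"
        )

functional_class = _as_functional_optim(optim.Adam)  # -> _FunctionalAdam
```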
nit (can be done in followup PRs): curious, since callbacks on buckets fire in order, could this be simplified to a list of GradBucket objects?
For a given rank, we only need to store references to the GradBuckets that the rank uses to update. As such, even though the buckets fire in order, the "keys" for this data structure are only a subset of all indices.
If the total number of buckets is not large, using a list and either storing all buckets or storing Nones in unused positions may be faster than a dict, but I used a dict for now.
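For concreteness, a small sketch of the two storage options (names and counts here are illustrative; strings stand in for GradBucket objects so the snippet is self-contained):

```python
from typing import Dict, List, Optional

NUM_BUCKETS = 8    # assumed total number of DDP buckets
assigned = [1, 5]  # bucket indices this rank updates

# Dict approach: keys are only the assigned subset of bucket indices.
buckets_by_index: Dict[int, str] = {i: f"GradBucket {i}" for i in assigned}

# List alternative: one slot per bucket, None in unused positions;
# indexing avoids hashing at the cost of storing placeholders.
buckets_list: List[Optional[str]] = [None] * NUM_BUCKETS
for i in assigned:
    buckets_list[i] = f"GradBucket {i}"
```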
Ah, I see, makes sense.
nit (can be done in followup PRs): it looks like each Future in this dict is read just once, so this cache probably won't save on Python-C++ context-switch cost. Maybe this can be deleted and bucket.get_future() used directly?
In general, I would try to avoid the state vars unless that would give us either perf gain or considerable code simplification. As more vars might make it harder to understand and maintain.
I may be misunderstanding, but the reason I stored this dict is to have a reference to all relevant Futures when running the hook for the last bucket so that I can call wait() on them.

bucket is of type GradBucket, which does not offer a get_future() method; rather, c10d.ProcessGroup.work objects provide get_future(). At the moment, this seems a bit backward: the hook passed into hook_then_zero_step() (most typically allreduce_hook()) returns a Future, where that Future is exactly obtained via dist.all_reduce().get_future().

Either way, I do not see how to get around storing at least something per bucket-to-update that gives the rank access to the bucket's corresponding Future when executing the hook for the last bucket.
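As a rough sketch of this store-then-wait pattern (the function and variable names are illustrative, not the PR's actual code):

```python
# Futures accumulated across hook invocations, keyed by bucket index.
_bucket_index_to_future = {}

def _hook_then_zero_step_sketch(wrapped_hook, state, bucket, bucket_index,
                                is_last_bucket):
    # The wrapped hook (most typically allreduce_hook) returns the future
    # obtained via dist.all_reduce(...).get_future() for this bucket.
    fut = wrapped_hook(state, bucket)
    _bucket_index_to_future[bucket_index] = fut
    if is_last_bucket:
        # Only now does the rank hold references to every relevant
        # future, so the optimizer step can wait on all of them.
        for f in _bucket_index_to_future.values():
            f.wait()
    return fut
```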
curious, why this is device_to_params instead of device_to_buckets?
Good point -- I will change it. It was from before I created the class _DDPBucket, when I instead had a List[torch.Tensor] as the innermost element of the dict value.
(can be done in followup PRs) This conceptually overlaps with DDP's Gradient Bucket. We probably want to let DDP expose its GradBuckets instead of creating new classes?
I think that it may be possible to get rid of this class altogether with how DDP's GradBucket is exposed currently. I can investigate this in a follow-up PR.
This is still the allreduce future, right? Meaning that waiting on this future (by DDP backward()) won't guarantee that the ZeRO broadcast is done? We probably need to return torch.futures.collect_all(all_zero_broadcast_futures) for the last bucket instead.
Among these reviews, this is the most important one, as it might have an impact on correctness.
Resolved offline.
@andwgu Do you mind describing how this was resolved?
https://github.com/andwgu/pytorch/blob/ab54c52fa54dac017ff6c9eadae49a97aa117432/torch/distributed/algorithms/ddp_comm_hooks/ddp_zero_hook.py#L175
That line should provide the necessary synchronization already. (The wait() on the returned future in finalize_backward() will not guarantee that the ZeRO broadcasts are done, but these additional wait()s will.)
Also, as a side note, the torch.futures.collect_all(all_zero_broadcast_futures) would change the hook semantics since it would then return a future for the broadcasts instead of a future for the bucket all-reduce.
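To illustrate the semantics point, a self-contained snippet (not from the PR) showing that collect_all() produces a future over all of its inputs:

```python
import torch

futs = [torch.futures.Future(), torch.futures.Future()]
combined = torch.futures.collect_all(futs)  # Future[List[Future]]

for i, f in enumerate(futs):
    f.set_result(torch.full((1,), float(i)))

# Waiting on `combined` waits for *every* input future, which is why
# returning it from the hook would tie DDP's final wait to the ZeRO
# broadcasts rather than to the last bucket's all-reduce alone.
values = [f.value() for f in combined.wait()]
```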
@andwgu has imported this pull request. If you are a Facebook employee, you can view this diff on Phabricator.
This is a great amount of work! Amazing job! Thanks a lot for putting in significant effort to explore, diagnose, and optimize different approaches and to deliver high-quality code and docs. The code looks good to me.
Not stamping yet, mainly because we need tests ;)
state (object): ... ?
Fixed.
Actually, I was reading this: https://stackoverflow.com/questions/39817081/typing-any-vs-object
Even though the type annotation for state in register_comm_hook() in distributed.py is object, I wonder if that should be Any instead. If I am understanding correctly, if you annotate state with object, then to appease a type checker, every operation involving state must be valid for any object whatsoever. In other words, annotating with object is actually restrictive rather than all-encompassing.
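A minimal example of the distinction, following mypy's rules:

```python
from typing import Any

class HookState:
    bucket_index = 0

def use_object(state: object) -> None:
    # A type checker rejects this: `object` has no attribute
    # `bucket_index`, since the annotation only permits operations
    # that are valid for *every* object.
    print(state.bucket_index)

def use_any(state: Any) -> None:
    # Accepted: Any disables attribute checking entirely.
    print(state.bucket_index)

use_any(HookState())
```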
Curious, is this the conventional way to document a function signature? Why does it not match Callable[[Any, dist.GradBucket], torch.futures.Future]?
Good catch. I accidentally used the type signature style from the language Standard ML.
:meth:step -> :meth:`step`?
Fixed.
print a warning if parameters_as_bucket_view=True?
Added a warning using logging.warning() in that case.
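Presumably something along these lines (the wording of the message is my paraphrase, not the PR's exact text, and the helper is hypothetical):

```python
import logging

def _warn_if_bucket_view(parameters_as_bucket_view: bool) -> None:
    # Hypothetical helper; the PR likely emits the warning inline.
    if parameters_as_bucket_view:
        logging.warning(
            "parameters_as_bucket_view=True is ineffective when "
            "overlap_with_ddp=True since DDP's buckets define the views"
        )
```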
add some explanation that ZeroRedundancyOptimizer is fully initialized after DDP buckets are rebuilt?
Added.
assert overlap_with_ddp == False?
Added.
IIUC, the recommended new API for this is param.set_(). Let's do that change in a followup PR, as there might be subtle differences.
https://pytorch.org/docs/stable/generated/torch.Tensor.set_.html#torch.Tensor.set_
One issue with using set_() is that we must set requires_grad=False and later reset it to True, since otherwise we get an error saying that a leaf Variable that requires grad is being used in an in-place operation.

More importantly, I am unable to get it to manage the strides correctly. It seems that if you specify an offset, then you must also specify size and stride arguments, and the method uses that stride when accessing the shared underlying storage. However, we want param to preserve its non-flattened stride while sharing storage with the flattened bucket. This suggests the need for additional view manipulation (which I have not figured out how to get correct). Hence, as far as I can see, using set_() would not be any simpler for this copying.
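For context, a self-contained sketch of view-based sharing that preserves the parameter's shape (my sketch under assumed sizes, not necessarily the PR's exact code), illustrating why set_()'s explicit size/stride arguments are awkward here:

```python
import torch

param = torch.nn.Parameter(torch.randn(2, 4))
bucket = torch.zeros(16)   # flattened storage shared by several params
offset = 0                 # this param's start index within the bucket

with torch.no_grad():
    numel = param.numel()
    # Copy the parameter into its slice of the flattened bucket ...
    bucket[offset:offset + numel].copy_(param.data.flatten())
    # ... and repoint the parameter at that slice, restoring its original
    # 2x4 shape; set_() would instead demand an explicit size/stride.
    param.data = bucket[offset:offset + numel].view(param.shape)

# The parameter now shares memory with the bucket.
bucket.fill_(1.0)
assert bool((param.data == 1.0).all())
```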
Do we need to error out if the backend is not NCCL? (as gloo might hang today)
Added check.
this is a bit different from the above hook, which also checks assert overlap_info.status == _OverlapStatus.UNINITIALIZED
This means there is no optimizer step in the first 2 iterations and gradients are discarded? Do we need to mention this in the doc, or do you plan to change this in followup PRs?
I added warnings to mention this in both hook functions and ZeroRedundancyOptimizer.
The motivation is to avoid initializing a non-functional local optimizer and then re-initializing a functional optimizer soon thereafter. Even if we did that, we would have to transfer the optimizer state correctly.
We do not know until DDP buckets are built which parameters a given rank is assigned to update. I have not yet come up with a clean solution that allows not throwing away the first two inputs.
Small note: The first three inputs are unused if static_graph=True for the DDP model.
(can be done in followup PRs)
We do not know until DDP buckets are built which parameters a given rank is assigned to update. I have not yet come up with a clean solution that allows not throwing away the first two inputs.
Since ZeRO's overlap state knows whether the buckets have been rebuilt, is it possible that, in the first 2-3 iterations, ZeroRedundancyOptimizer.step() still performs computation + communication the same way as in the no-overlap case, and only after the buckets are rebuilt does ZeroRedundancyOptimizer.step() become a no-op?
Yes, we can do that. Originally, I was averse to the idea of re-initializing the local optimizer, but upon further thought, providing an equivalent functionality to the user is definitely more important.
What we can do is assume parameters_as_bucket_view=False behavior until buckets are rebuilt. This means that the parameters will be partitioned according to the normal sorted-greedy algorithm and a local optimizer will be initialized as normal at construction time. Then, when buckets are rebuilt, the parameter partitioning will change and the local optimizer will be reinitialized. I can do this in a follow-up PR.
I made these uniform across the file to use two colons. I checked the render, and the warning still appears as desired.
Unit tests should be coming tomorrow.
I added unit tests. The tests are based on those by @rohan-varma:
However, I chose not to de-duplicate them since there were still some notable differences in the setup, and since we may want to keep ZeroRedundancyOptimizer tests fully separate from general distributed tests.
This will not actually be rendered since the method is private, but just in case, I think the docstring render expects a newline before any bulleted list.
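That is, something like:

```python
def _example(self) -> None:
    """One-line summary.

    - The blank line above lets this bulleted list render correctly.
    - Without it, the bullets get merged into the preceding paragraph.
    """
```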
Similar thing here.
This rename is to prevent implicit shadowing of the method's argument params_per_rank.
If not, this should error out, as it might hang with Gloo?
NVM, I saw there are TODOs to cover CPU tests as well.
Curious, why do we need a barrier here?
Good point. Removed.
sort?
Sorted.
Do we have to explicitly import distributed_c10d? If the goal was to use Backend.NCCL, will dist.Backend.NCCL work?
https://pytorch.org/docs/master/distributed.html#torch.distributed.Backend
Changed to dist.get_backend() and dist.Backend.NCCL.
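That is, a check along these lines (the exact error message is approximate, and the wrapping helper is hypothetical):

```python
import torch.distributed as dist

def _check_nccl_backend() -> None:
    # Assumes the default process group has already been initialized.
    if dist.get_backend() != dist.Backend.NCCL:
        raise RuntimeError(
            "Overlapping DDP with ZeroRedundancyOptimizer requires the "
            "NCCL backend; other backends (e.g., Gloo) may hang"
        )
```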
Let's cite issue #62300 here
Added.
LGTM. Left some minor comments. Stamp to unblock. Please wait for tests.
Shall we add a check here to error out if this assumption was violated? (maybe by remembering the last bucket index)
This is a good idea. Currently, I am a bit uncertain about how to do this; I will address it in a follow-up PR.

Remembering the last bucket index might be circular since the way we currently tell that a bucket is the last one is by its index. If the buckets fire in order, then we should be guaranteed that the CPU part of each hook executes in order, right? In that case, we can reuse the bucket_indices_seen data structure used in the other hook and check that the current bucket index is 1 more than the previous index seen (clearing it after each iteration).
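A rough sketch of that check, reusing names from this thread (the method names themselves are hypothetical):

```python
class _OverlapInfo:
    def __init__(self) -> None:
        self.bucket_indices_seen = []

    def check_fires_in_order(self, bucket_index: int) -> None:
        # Each hook call must see the index immediately following the
        # previously seen one if buckets truly fire in order.
        prev = self.bucket_indices_seen[-1] if self.bucket_indices_seen else -1
        assert bucket_index == prev + 1, (
            f"bucket {bucket_index} fired out of order (previous: {prev})"
        )
        self.bucket_indices_seen.append(bucket_index)

    def clear_per_iteration_info(self) -> None:
        # Reset at the end of each backward pass.
        self.bucket_indices_seen.clear()
```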
Let's add a short comment to explain why this cannot be chained as a callback to the returned future.
Added:

```python
# NOTE: This should not be chained as a callback to the last bucket's
# all-reduce future since that would add synchronization that delays
# all optimizer computation to wait for that last all-reduce
```
nit (can be done in a followup PR, and this is just my personal preference, please feel free to ignore): the _OverlapInfo type is not owned by this hook function and lives a bit far away from this file. It might be better to wrap its state manipulations into member methods on _OverlapInfo instead of directly manipulating them here and above.
I think this definitely makes sense. I will try to refactor this in a follow-up PR.
Ah, I see, makes sense.
Test failures on ROCm are real. If this is ROCm only, it's OK to skip tests on ROCm for now.
21:51:57 ======================================================================
21:51:57 ERROR [100.077s]: test_ddp_with_zero_step_interleaved_parity_gpu (__main__.TestZeroRedundancyOptimizerDistributed)
21:51:57 ----------------------------------------------------------------------
21:51:57 Traceback (most recent call last):
21:51:57 File "/opt/conda/lib/python3.6/site-packages/torch/testing/_internal/common_distributed.py", line 406, in wrapper
21:51:57 self._join_processes(fn)
21:51:57 File "/opt/conda/lib/python3.6/site-packages/torch/testing/_internal/common_distributed.py", line 608, in _join_processes
21:51:57 self._check_return_codes(elapsed_time)
21:51:57 File "/opt/conda/lib/python3.6/site-packages/torch/testing/_internal/common_distributed.py", line 653, in _check_return_codes
21:51:57 raise RuntimeError('Process {} terminated or timed out after {} seconds'.format(i, elapsed_time))
21:51:57 RuntimeError: Process 0 terminated or timed out after 100.057626247406 seconds
21:51:57
21:51:57 ======================================================================
21:51:57 ERROR [100.061s]: test_ddp_with_zero_step_parity_gpu (__main__.TestZeroRedundancyOptimizerDistributed)
21:51:57 ----------------------------------------------------------------------
21:51:57 Traceback (most recent call last):
21:51:57 File "/opt/conda/lib/python3.6/site-packages/torch/testing/_internal/common_distributed.py", line 406, in wrapper
21:51:57 self._join_processes(fn)
21:51:57 File "/opt/conda/lib/python3.6/site-packages/torch/testing/_internal/common_distributed.py", line 608, in _join_processes
21:51:57 self._check_return_codes(elapsed_time)
21:51:57 File "/opt/conda/lib/python3.6/site-packages/torch/testing/_internal/common_distributed.py", line 653, in _check_return_codes
21:51:57 raise RuntimeError('Process {} terminated or timed out after {} seconds'.format(i, elapsed_time))
21:51:57 RuntimeError: Process 0 terminated or timed out after 100.03943800926208 seconds
Lint failures are real. Please fix before landing.
I will ask @rohan-varma why he skips ROCm for his tests: it is likely a similar issue.

Lint issues were resulting from the ...
```diff
     ddp: DistributedDataParallel,
     zero: ZeroRedundancyOptimizer,
-) -> Callable[[Any, dist.GradBucket], torch.futures.Future]:
+) -> Callable[[Any, dist.GradBucket], torch.futures.Future[torch.Tensor]]:
```
I ran all of the tests via ci-all, and for a few of the tests, I saw:
ValueError: Communication hook: return annotation should be torch.futures.Future[torch.Tensor].
As a result, I changed torch.futures.Future -> torch.futures.Future[torch.Tensor].
Codecov Report
```
@@            Coverage Diff             @@
##           master   #62157      +/-   ##
==========================================
- Coverage   60.32%   60.21%   -0.12%
==========================================
  Files         660      661       +1
  Lines       86440    86710     +270
==========================================
+ Hits        52148    52215      +67
- Misses      34292    34495     +203
```
Overview: This refactors some commonalities between the two approaches to overlapping DDP with ZeRO. This also partially addresses this comment: #62157 (comment)

Test Plan:

```
gpurun4 python test/distributed/optim/test_zero_redundancy_optimizer.py
```

Differential Revision: D30058543 (https://our.internmc.facebook.com/intern/diff/D30058543)
Overview:

This adds two approaches to overlapping DistributedDataParallel.backward() with ZeroRedundancyOptimizer.step() by providing two hook constructors: hook_with_zero_step() and hook_with_zero_step_interleaved(). The former waits for all backward computation to finish before starting optimizer computation, while the latter launches a partial optimizer computation using the contents of a gradient bucket once that bucket's all-reduce completes. The two approaches each suffer from their own weaknesses, and which one to use depends on the specific hardware configuration.

Both approaches share changes to ZeroRedundancyOptimizer. A user should pass overlap_with_ddp=True to ZeroRedundancyOptimizer, construct a DDP communication hook using either hook_with_zero_step() or hook_with_zero_step_interleaved(), and register that communication hook. ZeroRedundancyOptimizer.step() should still be called in the training loop, though the optimizer computation and communication will be offloaded to originate from the communication hook. Currently, the first two iterations are vacuous, meaning they do not result in parameter updates and the inputs are ignored. This is required to finalize the DDP bucket strategy and to then initialize the ZeroRedundancyOptimizer's local optimizer based on that bucketing. (A usage sketch follows the test plan below.)

Test Plan:

The existing ZeroRedundancyOptimizer tests pass, and new unit tests for both hooks pass:
- test_ddp_with_zero_step_parity_cpu (removed for now due to flakiness in CI -- under investigation; could possibly be a similar Gloo issue as with hook_with_zero_step_interleaved())
- test_ddp_with_zero_step_parity_gpu
- test_ddp_with_zero_step_interleaved_parity_gpu

These were tested on the AI AWS cluster.

An analogous test_ddp_with_zero_step_interleaved_parity_cpu is missing due to existing bugs with Gloo. See #62302.

Both approaches have been verified using an internal accuracy benchmark.
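Based on this description, a hedged end-to-end usage sketch (import paths follow the files touched by this PR, but exact signatures may have shifted; treat this as a sketch, not canonical API):

```python
import torch
import torch.distributed as dist
from torch.distributed.algorithms.ddp_comm_hooks.ddp_zero_hook import (
    hook_with_zero_step,
)
from torch.distributed.algorithms.ddp_comm_hooks.default_hooks import (
    allreduce_hook,
)
from torch.distributed.optim import ZeroRedundancyOptimizer
from torch.nn.parallel import DistributedDataParallel as DDP

# Assumes dist.init_process_group("nccl", ...) has already run and this
# process owns one GPU.
model = DDP(
    torch.nn.Linear(10, 10).cuda(),
    device_ids=[torch.cuda.current_device()],
)
zero = ZeroRedundancyOptimizer(
    model.parameters(),
    optimizer_class=torch.optim.SGD,
    overlap_with_ddp=True,
    lr=1e-2,
)
model.register_comm_hook(None, hook_with_zero_step(allreduce_hook, model, zero))

for _ in range(5):  # recall: the first two iterations are vacuous
    loss = model(torch.randn(8, 10).cuda()).sum()
    loss.backward()
    zero.step()  # optimizer work actually originates from the comm hook
```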