Clean Up ZeRO #60285
Conversation
💊 CI failures summary and remediations

As of commit b2cbb13 (more details on the Dr. CI page and at hud.pytorch.org/pr/60285):

🕵️ 1 new failure recognized by patterns

The following CI failures do not appear to be due to upstream breakages:
This is not needed anymore with modern PyTorch; it was there initially because this functionality was not present. I think you could remove it altogether.
+1
I looked into replacing _broadcast_object() with torch.distributed.broadcast_object_list() as suggested by @mrshenli. However, broadcast_object_list() yields a significant slowdown in my preliminary testing. Precisely, running test_collect_shards() on the AI AWS cluster with a world size of 4 regresses from ~6 seconds to ~18 seconds (where the affected function is consolidate_state_dict()).
I am unable to resolve the issue, but this is what I found: There are two (possibly-related) points of slowdown in broadcast_object_list():
- The implicit `torch.cuda.synchronize()` resulting from moving the data in `object_sizes_tensor` from GPU to CPU via `.item()` takes ~4 seconds for one of the broadcasts on ranks 1 and 2 (while it generally takes ~1 millisecond). This confuses me since `_broadcast_object()` follows the exact same data-movement pattern, yet it does not suffer from a slow synchronization.
- Unlike `_broadcast_object()`, `broadcast_object_list()` pickles/unpickles the object via `_object_to_tensor()`/`_tensor_to_object()`, respectively. `_tensor_to_object()` consistently takes ~4 seconds.

I am unable to see any connection between the two, but the similarity between their latencies is curious. Together, these two points cause `consolidate_state_dict()` to take >4 seconds when using `broadcast_object_list()`, while it only takes a few milliseconds when using `_broadcast_object()`.
Any help investigating this is greatly appreciated!
NB: Without pickling, the tensor allocations are 939, 939, and 1131 bytes, while with pickling, the allocations are 571, 591, and 755 bytes.
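For context, the suggested replacement amounts to something like the following (a minimal sketch assuming an already-initialized process group; `broadcast_state` is a hypothetical wrapper, not code from this PR):

```python
import torch.distributed as dist

def broadcast_state(obj, src_rank, group=None):
    # Every rank must pass a list of the same length; entries on non-source
    # ranks are placeholders that broadcast_object_list() overwrites in place.
    object_list = [obj if dist.get_rank() == src_rank else None]
    dist.broadcast_object_list(object_list, src=src_rank, group=group)
    return object_list[0]
```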
I think I found the issue, and fixing this will require changing code in distributed_c10d.py. Here is my hypothesis:
Currently, broadcast_object_list() takes care to specially handle the NCCL backend case, making sure to move object_sizes_tensor and object_tensor to device memory before broadcasting. Upon receiving the broadcasted tensor, the code calls _tensor_to_object(), which loads the tensor to an object. The existing implementation uses a Pickler and Unpickler to save/load, but I have verified that the slowdown behavior is replicated using torch.save() and torch.load(), which themselves use pickling. For the remainder, I will assume a torch.save()/torch.load()-based implementation of _object_to_tensor() and _tensor_to_object().
The crux is the default behavior of torch.load() (and presumably Unpickler.load()), which deserializes (on CPU) and loads contained tensors to the device where they were saved. Suppose that the world size is 4. Then, in the received local_state_dict on rank j from the broadcast, the contained tensors are loaded to cuda:i where i != j. The time bottleneck appears to be exactly in loading these received objects to device, taking > 3.5 seconds.
The bandaid fix I have right now is changing:

- `_object_to_tensor(obj)` to `_object_to_tensor(obj, pickle=True)`
- `_tensor_to_object(tensor, tensor_size)` to `_tensor_to_object(tensor, tensor_size, pickle=True, device=None)`
If pickle == True, then the behavior is as-is to preserve backward compatibility. Otherwise, torch.save(obj, f) and torch.load(io.BytesIO(buf), map_location=device) are used instead of the _pickler and _unpickler, respectively. Then, broadcast_object_list() passes in device=current_device to _tensor_to_object() if using NCCL backend. I have confirmed that this change removes the slowdown.
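For illustration, a minimal sketch of the `torch.save()`/`torch.load()` path described above (the helper names and the extra `device` argument are the proposed changes, not existing API):

```python
import io
import torch

def _object_to_tensor_via_save(obj):
    # Serialize with torch.save() instead of the module-level _pickler.
    buf = io.BytesIO()
    torch.save(obj, buf)
    data = buf.getvalue()
    byte_tensor = torch.tensor(bytearray(data), dtype=torch.uint8)
    size_tensor = torch.tensor([byte_tensor.numel()], dtype=torch.long)
    return byte_tensor, size_tensor

def _tensor_to_object_via_load(tensor, tensor_size, device=None):
    # map_location keeps contained tensors on the local device rather than the
    # device they were saved from, avoiding the slow cross-device load.
    buf = tensor.cpu().numpy().tobytes()[:tensor_size]
    return torch.load(io.BytesIO(buf), map_location=device)
```

With this, `broadcast_object_list()` would pass the current CUDA device as `device` when the backend is NCCL, as described above.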
However, this solution is only a temporary fix, and I think a more thorough examination is needed. I would need to see if a solution is feasible while still using the Pickler/Unpickler (though at first glance this is not promising since Unpickler.load() provides no options).
I do wonder how much this behavior has affected the speed of other use cases. It seems to be pretty common that the received tensor from the broadcast is from a different device -- in fact, that should be the typical case?
Discussed with @andwgu offline on this. Since there are gaps in the existing broadcast_object_list(), we are going to keep the current _broadcast_object method. Later, if bandwidth allows, we can work on a follow-up PR to improve broadcast_object_list() and use it here.
This was for compatibility with PyTorch 1.5 (this code comes from fairscale, which is compatible with older PyTorch, current minus two, and at the time this came out that was 1.5). It is indeed not needed anymore. There were a couple of update PRs for this code which never got reviewed, so it's really good that you can spend some time on this.
Looks good to me; some of these changes were overdue indeed. Some of the weird parts that you're seeing came from the retrocompatibility with PyTorch 1.5, which notably had a different way of indexing the parameters in the state dict (and this code, coming from fairscale, had to be compatible with 1.5, 1.6, and 1.7 checkpoints). In general this was not kept in sync with fairscale, and that's a shame. I would recommend that you have a look at the unit tests (fairscale's are here; not perfect either, but there are probably a couple more cases being checked). They could probably be improved as well; from experience it's easy to overlook a corner case.
For sure, thanks for taking a look! I will read the Fairscale tests.
Thanks for cleaning up the implementation!
We have both _update_trainable and _init_local_optimizer called here?
I think following your other comment and removing _update_trainable() altogether in favor of only _build_param_buckets() (which checks parameters_as_bucket_view) might make this more clear. Then, the constructor will call _init_local_optimizer() to initialize the optimizer and call _build_param_buckets() to build the buckets, as the function names suggest.
I think we can get rid of this method entirely and just call _build_param_buckets and inside _build_param_buckets have the check for self.parameters_as_bucket_view.
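Roughly the shape being suggested, as a sketch (not the PR's final code):

```python
def _build_param_buckets(self) -> None:
    # The parameters_as_bucket_view check lives inside the bucket builder,
    # so callers no longer need a separate _update_trainable() wrapper.
    if not self.parameters_as_bucket_view:
        return
    # ... flatten each rank's parameters into per-device buckets here ...
```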
I'm guessing this is referring to NCCL? If so, we should mention that here; and if we don't already have a GitHub issue, we should open one mentioning the lack of gather support in NCCL and add a code pointer for this to that issue.
Off the top of my head, that was with older Gloo (the compatibility matrix was PyTorch 1.5/1.6/1.7, NCCL/Gloo, and CPU/GPU with Gloo); maybe this is worth revisiting?
We shouldn't have notes about test methods in the docs of public methods. Also, shouldn't this method be private? Do we actually expect users to call this method? If not, it should be private.
I'm wondering if we need this interaction between state_dict and consolidate_state_dict? Every time we call .step(), we could set a variable needs_consolidation=True. Then, when state_dict is called, we first consolidate the state dicts if needs_consolidation=True, cache the result, and set needs_consolidation=False. This way the user only has to worry about calling state_dict.
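The pattern being proposed would look roughly like this (a hypothetical sketch only; the reply below explains why it does not work for ZeRO):

```python
class _CachingSketch:
    """Hypothetical sketch of the proposed caching pattern (not the PR's code)."""

    def __init__(self, local_optimizer):
        self.optim = local_optimizer
        self._needs_consolidation = False
        self._consolidated_state = None

    def step(self, closure=None):
        loss = self.optim.step(closure=closure)
        self._needs_consolidation = True  # local state changed; cache is stale
        return loss

    def consolidate_state_dict(self):
        # In the real optimizer this gathers every rank's shard via collectives.
        self._consolidated_state = [self.optim.state_dict()]

    def state_dict(self):
        if self._needs_consolidation:
            # Collective call: valid only if every rank reaches this point,
            # which is exactly the problem raised in the reply below.
            self.consolidate_state_dict()
            self._needs_consolidation = False
        return self._consolidated_state
```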
No, that would not work because of the collective calls needed. The issue here is that users might call .state_dict() on one rank only (in practice they often will), which means that you cannot consolidate at that point with any of the collective communication primitives. Something like RPC would alleviate that, but that's a bigger change. Without it, you need the collective consolidation before you can checkpoint (and doing this for every step is super slow, of course).
Functions start with
Thanks a lot for improving code quality!
This can be replaced by the following one?

pytorch/torch/distributed/distributed_c10d.py, line 1703 in 5824a86:
`def broadcast_object_list(object_list, src=0, group=None):`
+1
Please feel free to ignore if you already have a convenient way to inspect generated docs. I usually use this gist to serve html files locally.
(This is prior to this PR) IIUC, a Tensor will return its id by default when used as a hash key. Is it necessary to call id(p) here?
If this is unnecessary, we can also revert the var name to `_param_to_index_cache`.
(seconding that personally, but soft preference)
I examined the hashing for both Tensor and Parameter, and neither seems to use id() by default.
>>> model = torch.nn.Linear(1, 1)
>>> t = model.weight.data # t: torch.Tensor
>>> d = dict()
>>> d[t] = 1
>>> d[id(t)]
Traceback (most recent call last):
File "<stdin>", line 1, in <module>
KeyError: 139809367343568
>>> p = model.weight # p: torch.nn.parameter.Parameter
>>> d = dict()
>>> d[p] = 1
>>> d[id(p)]
Traceback (most recent call last):
File "<stdin>", line 1, in <module>
KeyError: 139809367327664
Let me know if I tested this incorrectly.
Thanks for checking this. My mistake.
Hold on, I am not sure the test is correct. We are passing a Tensor: int pair, but checking an int: int pair? Isn't that supposed to fail? Below is Tensor's hash function:
Lines 603 to 606 in 8dd1dc8:

    def __hash__(self):
        if has_torch_function_unary(self):
            return handle_torch_function(Tensor.__hash__, (self,), self)
        return id(self)
Ah, I misunderstood. Yes, I believe we can switch back to _param_to_index and make it of type dict[torch.Tensor, int]. Then, instead of the key being id(p), it can directly be p.
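For completeness, the corrected check: keying the dict on the tensor object itself works, since `Tensor.__hash__` returns `id(self)` per the snippet above.

```python
>>> import torch
>>> p = torch.nn.Linear(1, 1).weight  # p: torch.nn.parameter.Parameter
>>> d = {p: 1}
>>> d[p]  # hash(p) == id(p), and the identical object is found
1
```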
nit: other parts of this file and this code base usually break long arg lists into one arg per line, e.g.:

pytorch/torch/distributed/nn/api/remote_module.py
Lines 111 to 119 in 5824a86:

    class _RemoteModule(nn.Module):
        def __init__(
            self,
            remote_device: str,
            module_cls: Type[nn.Module],
            args: Tuple = None,
            kwargs: Dict[str, Any] = None,
            _module_interface_cls: Any = None,
        ):
Could you please post a screenshot of the rendered docs? Thanks!

I changed it to be like this because any other text on the same line as `* state` or `* param_groups` appears in the grey box, but anything on the next line appears underneath in the white. The old way looked strange, with "differs between optimizer classes." being indented and in the white, while "`* state` - a dict holding current optimization state. Its content" was in the grey box above.
add an error message for this assert?
has -> have?
Very minor, but I believe the current way is correct, since it should be "trainability [...] has changed" rather than "trainability [...] have changed" (trainability being singular, not plural).
:)
Oh I see, my bad, thanks!
I would remove that method altogether; I don't think there's actually a use case for it (we don't expose a way to reload local state dicts).
Could be worth adding an assert with a nice error message? Not super likely to be triggered, but in case it happens, it would help debugging a fair bit.
The assumptions in the `NOTE: ...` comments should be verified at the very beginning of the constructor (line 171). I was intending the notes for documentation purposes (for a user reading the source or for a developer planning to modify the implementation).
Do you think we should do a second pass of verification each time _build_param_buckets() is called?
Sorry, a bit late to react. I meant checking for this assumption explicitly (at least the first time the buckets are built?) and erroring out nicely if there is any issue, but maybe it was already there indeed. Non-blocking, just a brain dump.
Force-pushed from 66962da to 0943803.
LGTM!
@andwgu you might want to rebase to avoid the merge conflicts in tests.
shall we use one arg per line format to stay consistent with other APIs in this file and other files?
(no need to change this PR, just a heads-up) There is another recursive_to in DDP. It might be helpful to consolidate these two into one common utility function.

pytorch/torch/nn/parallel/distributed.py, line 907 in 4887c6e:
`def _recursive_to(self, inputs, target_gpu):`
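A consolidated utility might look roughly like the following (an illustrative sketch only, not the existing DDP helper):

```python
import torch

def recursive_to(value, device):
    # Hypothetical shared helper: move any tensors found in (possibly nested)
    # containers to the given device, leaving other values untouched.
    # (Namedtuples and custom containers would need special casing.)
    if isinstance(value, torch.Tensor):
        return value.to(device)
    if isinstance(value, dict):
        return {k: recursive_to(v, device) for k, v in value.items()}
    if isinstance(value, (list, tuple)):
        return type(value)(recursive_to(v, device) for v in value)
    return value
```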
(no need to change this PR, just FYI) If you would like to add a note block, this needs to be `.. note::`. But since this is a private API and the docs won't render, the current one also works.

pytorch/torch/nn/parallel/distributed.py, line 216 in 4887c6e:
`.. note::`
@andwgu has imported this pull request. If you are a Facebook employee, you can view this diff on Phabricator.
Overview:

Being relatively new to PyTorch and ZeRO, I found parts of the code slightly hard to follow. This change strives to clean up the `ZeroRedundancyOptimizer` code in `zero_redundancy_optimizer.py` by reorganizing some computations, making variable names more explicit and consistent, and unifying terminology in the documentation. The goal is for the code to be easier to extend afterwards.

Changes:

- `state_dict()`: The logic for updating the global `state_dict` with each rank's local `state_dict` is simplified and made more explicit. Notably, the `dict` `local_index_to_param_id` is unneeded: it maps `local_pg["params"][i]` to `id(global_pg["params"][i])`, so it is equivalent to make a single pass over both lists in tandem, effectively iterating over `i`, without needing the explicit `dict` (see the sketch after this list).
- `_update_trainable()`: The function initializes the local optimizer if it does not exist. I am unaware of any reason for the local optimizer to be destroyed after initialization, so I moved that logic to its own function `_init_local_optimizer()`, which is called once in the constructor. After discussion, I removed the function `_update_trainable()` itself in favor of adding a check for `parameters_as_bucket_view` in `_build_param_buckets()` directly.
- `rank_local_state_dict()`: This function is currently broken. It appears to be legacy and relies on the input `state_dict` to have the key `"partitions"`. For now, I have removed it and added an issue. Is it a notable use case to want to access another rank's `state_dict` in particular (as opposed to consolidating the entire state and then accessing it)?
- `local_state_dict()`: After discussion, I removed the function.
- `partition_parameters()`: After discussion, I renamed the function to `_partition_parameters()` to mark it as private.
- `_param_to_index`: After discussion, I changed the key to be the parameter itself rather than its integer ID.
- `buckets`: I renamed the data structure to `_buckets` to mark it as private.
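To illustrate the tandem pass mentioned in the first bullet, here is a rough sketch (the function and variable names are illustrative, not the PR's actual code):

```python
# Illustrative sketch of the tandem pass; local_sd/global_sd are hypothetical
# names for one rank's local optimizer state dict and the global (consolidated)
# state dict being assembled.
def merge_local_into_global(local_sd: dict, global_sd: dict) -> None:
    for local_pg, global_pg in zip(local_sd["param_groups"], global_sd["param_groups"]):
        # Walk both "params" index lists in tandem; no index-to-id dict needed.
        for local_index, global_index in zip(local_pg["params"], global_pg["params"]):
            if local_index in local_sd["state"]:
                global_sd["state"][global_index] = local_sd["state"][local_index]
```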
Test Plan:

The test suite passes as expected (on the AI AWS cluster):

I visually inspected the generated HTML doc (as generated following this).