Clean Up ZeRO by awgu · Pull Request #60285 · pytorch/pytorch · GitHub

Conversation

@awgu
Collaborator

@awgu awgu commented Jun 18, 2021

Overview:
Being relatively new to PyTorch and ZeRO, I found parts of the code slightly hard to follow. This change cleans up the ZeroRedundancyOptimizer code in zero_redundancy_optimizer.py by reorganizing some computations, making variable names more explicit and consistent, and unifying the terminology in the documentation. The goal is to make the code easier to extend going forward.

Changes:

  1. state_dict(): The logic for updating the global state_dict with each rank's local state_dict is simplified and made more explicit. Notably, the dict local_index_to_param_id is unneeded: it maps local_pg["params"][i] to id(global_pg["params"][i]), so making a single pass over both lists in tandem (effectively iterating over i) is equivalent and requires no explicit dict (see the sketch after this list).
  2. _update_trainable(): The function initializes the local optimizer if it does not exist. I am unaware of any reason for the local optimizer to be destroyed after initialization, so I moved that logic to its own function _init_local_optimizer(), which is called once in the constructor.
    After discussion, I removed the function _update_trainable() itself in favor of adding a check for parameters_as_bucket_view in _build_param_buckets() directly.
  3. rank_local_state_dict(): This function is currently broken. It appears to be legacy and relies on the input state_dict to have the key "partitions". For now, I have removed it and added an issue. Is it a notable use case to want to access another rank's state_dict in particular (as opposed to consolidating the entire state and then accessing)?
  4. local_state_dict(): After discussion, I removed the function.
  5. partition_parameters(): After discussion, I renamed the function to _partition_parameters() to mark it as private.
  6. _param_to_index: After discussion, I changed the key to be the parameter itself rather than its integer ID.
  7. buckets: I renamed the data structure to _buckets to mark it as private.
  8. Terminology: I tried to reduce the set of terms being used instead of juggling a number of synonyms. In particular, I made an effort to distinguish between "local" and "global" and to make names more indicative of typing.
  9. Style: Per the PyTorch contributing guide, I made all docstrings abide by the 80-character limit, except for the one line showing the example ZeRO usage. Some code lines violate the limit for readability. I also unified some minor stylistic choices out of habit.
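
A minimal sketch of the tandem iteration from change 1, purely for illustration (the helper name merge_local_state and the exact dict layout are assumptions, not the actual ZeroRedundancyOptimizer internals):

from typing import Any, Dict, List

def merge_local_state(
    global_state: Dict[str, Any],
    local_state: Dict[str, Any],
    global_param_groups: List[Dict[str, Any]],
) -> Dict[str, Any]:
    # Walk the local and global param groups in tandem; the i-th local param
    # index lines up with the i-th global parameter, so no explicit
    # local_index_to_param_id mapping is needed.
    for local_pg, global_pg in zip(local_state["param_groups"], global_param_groups):
        for local_index, global_param in zip(local_pg["params"], global_pg["params"]):
            if local_index in local_state["state"]:
                global_state["state"][global_param] = local_state["state"][local_index]
    return global_state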

Test Plan:
The test suite passes as expected (on the AI AWS cluster):

gpurun python test/distributed/optim/test_zero_redundancy_optimizer.py

I visually inspected the generated HTML doc (as generated following this).

@facebook-github-bot facebook-github-bot added the oncall: distributed and cla signed labels Jun 18, 2021
@awgu awgu requested a review from mrshenli June 18, 2021 17:36
@facebook-github-bot
Contributor

facebook-github-bot commented Jun 18, 2021

💊 CI failures summary and remediations

As of commit b2cbb13 (more details on the Dr. CI page and at hud.pytorch.org/pr/60285):


  • 2/2 failures possibly* introduced in this PR
    • 1/2 non-scanned failure(s)

🕵️ 1 new failure recognized by patterns

The following CI failures do not appear to be due to upstream breakages:

See CircleCI build pytorch_xla_linux_bionic_py3_6_clang9_test (1/1)

Step: "Run tests" (full log | diagnosis details | 🔁 rerun)

Jun 23 03:40:04 AssertionError: False is not tr... was 1.0 (1.0 vs. 0.0), which occurred at index 0.
Jun 23 03:40:04 ----------------------------------------------------------------------
Jun 23 03:40:04 Traceback (most recent call last):
Jun 23 03:40:04   File "/opt/conda/lib/python3.6/site-packages/torch/testing/_internal/common_device_type.py", line 397, in instantiated_test
Jun 23 03:40:04     result = test_fn(self, *args)
Jun 23 03:40:04   File "/var/lib/jenkins/workspace/xla/test/../../test/test_view_ops.py", line 458, in test_transpose_inplace_view
Jun 23 03:40:04     self.assertEqual(t[1, 0], v[0, 1])
Jun 23 03:40:04   File "/var/lib/jenkins/workspace/xla/test/pytorch_test_base.py", line 605, in assertEqual
Jun 23 03:40:04     return DeviceTypeTestBase.assertEqual(self, x, y, *args, **kwargs)
Jun 23 03:40:04   File "/opt/conda/lib/python3.6/site-packages/torch/testing/_internal/common_utils.py", line 1407, in assertEqual
Jun 23 03:40:04     super().assertTrue(result, msg=self._get_assert_msg(msg, debug_msg=debug_msg))
Jun 23 03:40:04 AssertionError: False is not true : Tensors failed to compare as equal!With rtol=0.001 and atol=0.001, found 1 element(s) (out of 1) whose difference(s) exceeded the margin of error (including 0 nan comparisons). The greatest difference was 1.0 (1.0 vs. 0.0), which occurred at index 0.
Jun 23 03:40:04 
Jun 23 03:40:04 ----------------------------------------------------------------------
Jun 23 03:40:04 Ran 138 tests in 3.384s
Jun 23 03:40:04 
Jun 23 03:40:04 FAILED (failures=2, skipped=102)
Jun 23 03:40:04 
Jun 23 03:40:04 Generating XML reports...
Jun 23 03:40:04 Generated XML report: test-reports/python-unittest/test.......test.test_view_ops/TEST-TestViewOpsXLA-20210623034000.xml
Jun 23 03:40:04 + cleanup
Jun 23 03:40:04 + retcode=1


@awgu awgu requested a review from blefaudeux June 18, 2021 17:36
Contributor

this is not needed anymore with modern PyTorch; it was there initially because this functionality was not present. I think you could remove it altogether

Contributor

+1

Collaborator Author

I looked into replacing _broadcast_object() with torch.distributed.broadcast_object_list() as suggested by @mrshenli. However, broadcast_object_list() yields a significant slowdown in my preliminary testing. Precisely, running test_collect_shards() on the AI AWS cluster with a world size of 4 regresses from ~6 seconds to ~18 seconds (where the affected function is consolidate_state_dict()).

I am unable to resolve the issue, but this is what I found: There are two (possibly-related) points of slowdown in broadcast_object_list():

  1. The implicit torch.cuda.synchronize() resulting from moving data in object_sizes_tensor from GPU to CPU via .item() takes ~4 seconds for one of the broadcasts on ranks 1 and 2 (while it generally takes ~1 millisecond). This confuses me since _broadcast_object() follows the exact same data movement pattern, yet it does not suffer from a slow synchronization.
  2. Unlike _broadcast_object(), broadcast_object_list() pickles/unpickles the object via _object_to_tensor()/_tensor_to_object(), respectively. _tensor_to_object() consistently takes ~4 seconds.
    I am unable to see any connection between the two, but the similarity between their latencies is curious. Together, these two points cause consolidate_state_dict() to take >4 seconds when using broadcast_object_list(), while it only takes a few milliseconds when using _broadcast_object().

Any help investigating this is greatly appreciated!

NB: Without pickling, the tensor allocations are 939, 939, and 1131 bytes, while with pickling, the allocations are 571, 591, and 755 bytes.

Collaborator Author

@awgu awgu Jun 22, 2021

I think I found the issue, and fixing this will require changing code in distributed_c10d.py. Here is my hypothesis:

Currently, broadcast_object_list() takes care to specially handle the NCCL backend case, making sure to move object_sizes_tensor and object_tensor to device memory before broadcasting. Upon receiving the broadcasted tensor, the code calls _tensor_to_object(), which loads the tensor to an object. The existing implementation uses a Pickler and Unpickler to save/load, but I have verified that the slowdown behavior is replicated using torch.save() and torch.load(), which itself uses pickling. For the remainder, I will assume a torch.save()/torch.load()-based implementation of _object_to_tensor() and _tensor_to_object().

The crux is the default behavior of torch.load() (and presumably Unpickler.load()), which deserializes (on CPU) and loads contained tensors to the device where they were saved. Suppose that the world size is 4. Then, in the received local_state_dict on rank j from the broadcast, the contained tensors are loaded to cuda:i where i != j. The time bottleneck appears to be exactly in loading these received objects to device, taking > 3.5 seconds.

The bandaid fix I have right now is changing:

  • _object_to_tensor(obj) to _object_to_tensor(obj, pickle=True)
  • _tensor_to_object(tensor, tensor_size) to _tensor_to_object(tensor, tensor_size, pickle=True, device=None)

If pickle == True, then the behavior is as-is to preserve backward compatibility. Otherwise, torch.save(obj, f) and torch.load(io.BytesIO(buf), map_location=device) are used instead of the _pickler and _unpickler, respectively. Then, broadcast_object_list() passes in device=current_device to _tensor_to_object() if using NCCL backend. I have confirmed that this change removes the slowdown.
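
A hedged sketch of that torch.save()/torch.load() path (the helper names and signatures here are illustrative, not the actual distributed_c10d.py code):

import io
import torch

def _object_to_tensor_via_save(obj):
    buf = io.BytesIO()
    torch.save(obj, buf)
    byte_tensor = torch.ByteTensor(torch.ByteStorage.from_buffer(buf.getvalue()))
    size_tensor = torch.LongTensor([byte_tensor.numel()])
    return byte_tensor, size_tensor

def _tensor_to_object_via_load(tensor, tensor_size, device=None):
    buf = tensor.cpu().numpy().tobytes()[:tensor_size]
    # map_location=device loads any contained tensors onto the caller's own
    # device rather than the device they were saved from on the sender.
    return torch.load(io.BytesIO(buf), map_location=device)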

However, this solution is only a temporary fix, and I think a more thorough examination is needed. I would need to see whether a solution is feasible while still using the Pickler/Unpickler (though at first glance this is not promising since Unpickler.load() provides no options).

I do wonder how much this behavior has affected the speed of other use cases. It seems to be pretty common that the received tensor from the broadcast is from a different device -- in fact, that should be the typical case?

Contributor

Discussed with @andwgu offline on this. Since there are gaps in the existing broadcast_object_list(), we are going to keep the current _broadcast_object method. Later, if bandwidth allows, we can work on a followup PR to improve broadcast_object_list() and use it here.

Contributor

this was for compatibility with PyTorch 1.5 (this comes from fairscale, which is compatible with older PyTorch, current minus two, and at the time this came out that was 1.5). This is not needed anymore indeed. There were a couple of update PRs for this code which never got reviewed; it's really good that you can spend some time on that

@blefaudeux
Contributor

looks good to me, some of these changes were overdue indeed. Some of the weird parts that you're seeing came from the retrocompatibility with PyTorch 1.5, which notably had a different way of indexing the parameters in the state dict (and this code, coming from fairscale, had to be compatible with 1.5, 1.6, and 1.7 checkpoints). In general this was not kept in sync with Fairscale, and that's a shame. I would recommend that you have a look at the unit tests (fairscale's are here; not perfect either, but there are probably a couple more cases being checked); they could probably be improved also. From experience it's kind of easy to overlook a corner case

@awgu
Collaborator Author

awgu commented Jun 18, 2021

looks good to me, some of these changes were overdue indeed. Some of the weird parts that you're seeing came from the retrocompatibility with PyTorch 1.5, which notably had a different way of indexing the parameters in the state dict (and this code, coming from fairscale, had to be compatible with 1.5, 1.6, and 1.7 checkpoints). In general this was not kept in sync with Fairscale, and that's a shame. I would recommend that you have a look at the unit tests (fairscale's are here; not perfect either, but there are probably a couple more cases being checked); they could probably be improved also. From experience it's kind of easy to overlook a corner case

For sure, thanks for taking a look! I will read the Fairscale tests.

Contributor

@pritamdamania87 pritamdamania87 left a comment

Thanks for cleaning up the implementation!

Contributor

We have both _update_trainable and _init_local_optimizer called here?

Collaborator Author

I think following your other comment and removing _update_trainable() altogether in favor of only _build_param_buckets() (which checks parameters_as_bucket_view) might make this more clear. Then, the constructor will call _init_local_optimizer() to initialize the optimizer and call _build_param_buckets() to build the buckets, as the function names suggest.
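
A rough sketch of that shape, with the parameters_as_bucket_view check folded into the bucket-building step (the standalone function, its arguments, and the flattening details are illustrative assumptions, not the actual ZeroRedundancyOptimizer code):

from typing import Dict, List
import torch

def build_param_buckets(
    params_per_rank: List[List[torch.Tensor]],
    parameters_as_bucket_view: bool,
) -> Dict[int, torch.Tensor]:
    # Buckets are only needed when parameters are exposed as bucket views.
    if not parameters_as_bucket_view:
        return {}
    buckets: Dict[int, torch.Tensor] = {}
    for rank, params in enumerate(params_per_rank):
        if not params:
            continue
        # Flatten this rank's parameters into one contiguous bucket and make
        # each parameter a view into it.
        flat = torch.cat([p.detach().reshape(-1) for p in params])
        offset = 0
        for p in params:
            numel = p.numel()
            p.data = flat[offset:offset + numel].view_as(p)
            offset += numel
        buckets[rank] = flat
    return buckets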

Contributor

I think we can get rid of this method entirely and just call _build_param_buckets and inside _build_param_buckets have the check for self.parameters_as_bucket_view.

Contributor

I'm guessing this is referring to NCCL? If so, we should mention this here, and if we don't already have a GitHub issue, we should mention the lack of gather support in NCCL and add a code pointer for this to that issue.

Contributor

off the top of my head, that was with older Gloo (the compatibility matrix was PyTorch 1.5/1.6/1.7 & NCCL/Gloo & CPU/GPU with Gloo); maybe this is worth revisiting?

Comment on lines 335 to 338
Contributor

We shouldn't have notes about test methods in the docs of public methods. Also, shouldn't this method be private? Do we actually expect users to call this method? If not, it should be private.

Comment on lines +504 to +491
Contributor

I'm wondering if we need this interaction between state_dict and consolidate_state_dict. Every time we call .step(), could we set a variable needs_consolidation=True? Then, when state_dict is called, we first consolidate the state dicts if needs_consolidation=True, cache the result, and set needs_consolidation=False. This way the user only has to worry about calling state_dict.

Contributor

No, that would not work because of the collective requirements. The issue here is that users might call .state_dict() on one rank only (in practice they often will), which means that you cannot consolidate at that point with any of the collective communication primitives. Something like RPC would alleviate that, but that's a bigger change; without it, you need the collective consolidation before you can checkpoint (and doing this on every step is super slow, of course).
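
A small usage sketch of the resulting two-step pattern, assuming an already-initialized process group and a constructed ZeroRedundancyOptimizer (the helper name and checkpoint path are made up for illustration):

import torch
import torch.distributed as dist

def checkpoint_zero_optimizer(optimizer, path, recipient_rank=0):
    # Every rank must enter the collective consolidation together; this is why
    # it cannot happen lazily inside state_dict(), which a user may call on a
    # single rank only.
    optimizer.consolidate_state_dict(to=recipient_rank)
    # Only the recipient rank holds the full, consolidated state and writes
    # the checkpoint.
    if dist.get_rank() == recipient_rank:
        torch.save(optimizer.state_dict(), path)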

@mrshenli
Contributor

I visually inspected the generated HTML doc (as generated following this). Only the "public" functions (i.e. those without any preceding _) are rendered, so I am not sure how to verify the other functions.

Functions that start with _ are considered private, and we won't expose them in the API docs. Their docstrings can still help future developers understand the code.

Contributor

@mrshenli mrshenli left a comment

Thanks a lot for improving code quality!

Contributor

this can be replaced by the following one?

def broadcast_object_list(object_list, src=0, group=None):

Contributor

+1

Contributor

Please feel free to ignore if you already have a convenient way to inspect generated docs. I usually use this gist to serve html files locally.

Contributor

(This is prior to this PR) IIUC, a Tensor will return its id by default when used as hash key. Is it necessary to call id(p) here?

Contributor

if this is unnecessary, we can also revert the var name to _param_to_index_cache

Contributor

(seconding that personally, but soft preference)

Collaborator Author

I examined the hashing for both Tensor and Parameter, and neither seems to use id() by default.

>>> model = torch.nn.Linear(1, 1)
>>> t = model.weight.data  # t: torch.Tensor
>>> d = dict()
>>> d[t] = 1
>>> d[id(t)]
Traceback (most recent call last):
  File "<stdin>", line 1, in <module>
KeyError: 139809367343568
>>> p = model.weight  # p: torch.nn.parameter.Parameter
>>> d = dict()
>>> d[p] = 1
>>> d[id(p)]
Traceback (most recent call last):
  File "<stdin>", line 1, in <module>
KeyError: 139809367327664

Let me know if I tested this incorrectly.

Contributor

Thanks for checking this. My mistake.

Contributor

Hold on, I am not sure the test is correct. We are passing a Tensor: int pair, but checking an int: int pair? Isn't that supposed to fail? Below is Tensor's hash function:

pytorch/torch/_tensor.py

Lines 603 to 606 in 8dd1dc8

def __hash__(self):
    if has_torch_function_unary(self):
        return handle_torch_function(Tensor.__hash__, (self,), self)
    return id(self)

Collaborator Author

Ah, I misunderstood. Yes, I believe we can switch back to _param_to_index and make it of type dict[torch.Tensor, int]. Then, instead of the key being id(p), it can directly be p.
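
A quick standalone illustration of that conclusion, relying on Tensor.__hash__ returning id(self) as quoted above:

import torch

p = torch.nn.Linear(1, 1).weight   # a Parameter (subclass of Tensor)
param_to_index = {p: 0}            # dict[torch.Tensor, int]

assert hash(p) == id(p)            # the default Tensor hash is its id
assert param_to_index[p] == 0      # keying by the parameter itself works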

Contributor

nit: other parts of this file and this code base usually break long arg list into one arg per line, e.g.:

class _RemoteModule(nn.Module):
    def __init__(
        self,
        remote_device: str,
        module_cls: Type[nn.Module],
        args: Tuple = None,
        kwargs: Dict[str, Any] = None,
        _module_interface_cls: Any = None,
    ):

Contributor

Could you please post a screenshot of the rendered docs? Thanks!

Collaborator Author

[Screenshot: Screen Shot 2021-06-20 at 9.02.40 PM]
I changed it to be like this because any other text on the same line as "* state" or "* param_groups" appears in the grey box, but anything on the next line appears underneath in the white. The old way looked strange, with "differs between optimizer classes." being indented and in the white, while "* state - a dict holding current optimization state. Its content" was in the grey box above.

Contributor

add an error message for this assert?

Contributor

has -> have?

Collaborator Author

Very minor, but I believe the current way is correct, since it should be "trainability [...] has changed" rather than "trainability [...] have changed" (trainability being singular, not plural)
:)

Contributor

Oh I see, my bad, thanks!

Contributor

I would remove that method altogether; I don't think there's a use case for it, actually (we don't expose a way to reload local state dicts)

Contributor

could be worth adding an assert with a nice error message? Not super likely to be triggered, but in case it happens it would help debugging a fair bit

Collaborator Author

The assumptions in the NOTE: comments should be verified at the very beginning of the constructor (line 171). I was intending the notes for documentation purposes (for a user reading the source or for a developer planning to modify the implementation).

Do you think we should do a second pass of verifications each time this function _build_param_buckets() is called?

Contributor

sorry, a bit late to react. I meant checking for this assumption explicitly (at least the first time the buckets are built?) and erroring out nicely if there is any issue, but maybe it was already there indeed. Non-blocking, just a brain dump

@awgu awgu force-pushed the zero_cleanup branch 2 times, most recently from 66962da to 0943803 on June 22, 2021 at 20:43
Contributor

@mrshenli mrshenli left a comment

LGTM!

@andwgu you might want to rebase to avoid the merge conflicts in tests.

Contributor

shall we use one arg per line format to stay consistent with other APIs in this file and other files?

Contributor

(no need to change this PR, just a heads-up) there is another recursive_to in DDP. It might be helpful to consolidate these two into one common utility function.

def _recursive_to(self, inputs, target_gpu):

Contributor

(no need to change this PR, just FYI) if you would like to add a note block, this needs to be .. note::. But since this is a private API and the docs won't render, the current one also works.

@facebook-github-bot
Contributor

@andwgu has imported this pull request. If you are a Facebook employee, you can view this diff on Phabricator.

@facebook-github-bot
Contributor

@andwgu merged this pull request in f0e4e4b.

@awgu awgu deleted the zero_cleanup branch February 3, 2022 00:45