[DTensor] Support user-supplied Generator for random ops #159933
Conversation
If the user provides a generator kwarg to a random op (e.g. nn.init.uniform_(..., generator=my_generator)), we can still advance that generator's state in an SPMD-global way, so that each local tensor gets appropriate values and the generator advances to the same state as if it had operated on the full tensor.
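A minimal usage sketch of the feature (hypothetical: the mesh shape, tensor size, device, and seed are illustrative and not taken from this PR; assumes a torchrun launch with a 1-D device mesh):

```python
import os
import torch
from torch.distributed.device_mesh import init_device_mesh
from torch.distributed.tensor import Shard, distribute_tensor

world_size = int(os.environ["WORLD_SIZE"])
mesh = init_device_mesh("cuda", (world_size,))
dt = distribute_tensor(torch.empty(1024, 1024, device="cuda"), mesh, [Shard(0)])

# The user-supplied generator must start from the same state on every rank.
g = torch.Generator(device="cuda").manual_seed(42)

# Each rank fills only its local shard, but `g` advances as if it had generated
# the full 1024x1024 tensor, so later draws from it stay SPMD-consistent.
torch.nn.init.uniform_(dt, 0.0, 1.0, generator=g)
```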
🔗 Helpful Links: 🧪 See artifacts and rendered test results at hud.pytorch.org/pr/159933. ✅ You can merge normally (3 unrelated failures) as of commit 7a01a13 with merge base 908c5cc. The remaining jobs are marked as unstable, possibly due to flakiness on trunk. This comment was automatically generated by Dr. CI and updates every 15 minutes.
torch/distributed/tensor/_random.py (outdated)

```diff
-    def _distribute_region(self, spec: DTensorSpec):
+    def _distribute_region(
+        self, spec: DTensorSpec, generator: Optional[torch.Generator]
```
Missing `=None` default.
Overall LGTM. Let me know if you need a quick unblock.
We need to explicitly tell users to initialize the generator they pass in with the same seed, or add this functionality to our manual_seed API; either way, users still need to be aware of this.
```python
assert maybe_user_generator is None or isinstance(
    maybe_user_generator, torch.Generator
)
# maybe_user_generator = None
```
Remove this commented-out line.
torch/distributed/tensor/_random.py (outdated)

```python
if g_name not in self.rng_states:
    self.rng_states[g_name] = generator.get_state()
```
This is a slight behavior divergence between using the default RNG and a user-specified RNG. Either we require users to call our manual_seed() API with the RNG passed in, or we optimistically assume users know what they're doing and are responsible for initializing the RNG with the right seed value across ranks (which is what we're doing here).
I think I like the behavior this way: I don't want to introduce a collective on every op where a user supplies an RNG.
I will add documentation to the DTensor docs stating that it is the user's responsibility to ensure the passed generator has the same state on every SPMD rank.
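For illustration, a sketch of one way a user could guarantee identical generator state on every SPMD rank, by broadcasting a seed from rank 0 (the helper name and entropy source are hypothetical, not part of this PR):

```python
import torch
import torch.distributed as dist

def make_spmd_generator(device: str = "cuda") -> torch.Generator:
    """Create a torch.Generator seeded identically on every rank.

    Assumes the default process group (e.g. NCCL) is already initialized,
    for instance via init_device_mesh or init_process_group.
    """
    seed = torch.zeros(1, dtype=torch.int64, device=device)
    if dist.get_rank() == 0:
        # Rank 0 picks the seed; any agreed-upon source of entropy works.
        seed[0] = torch.randint(0, 2**62, (1,)).item()
    dist.broadcast(seed, src=0)
    g = torch.Generator(device=device)
    g.manual_seed(int(seed.item()))
    return g
```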
The original feature request comes from @akolesnikoff and me. We noticed the issue specifically because we were wondering about and comparing the behaviour of passing in identically vs. differently seeded RNGs across ranks. So at least from our perspective, yes, this is intentional.
I see.

> passing in same vs different seeded RNGs across ranks

I assume you're trying to pass the same RNGs for Data Parallel and different RNGs for Model Parallel (let me know if this is not the case).
torch/distributed/tensor/_random.py (outdated)

```python
# not because we need to keep a copy of it but because its the easiest way to make it work with the
# existing set/get APIs
g_name = str(id(generator))
if g_name not in self.rng_states:
```
I will also add 2 more test cases and fix one bug here:
If the user changed the seed of their generator after using it with DTensor, we'd cache the first seed and use it forever. To prevent this, I'm going to pop the temporary generator's state back out of self.rng_states at the end of this context, so we add it fresh every time.
This makes me want to do a refactor: I think we should not be storing states like this. We should probably just store a ref to our own generator and then use it.
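A rough sketch of the pop-on-exit behavior described above, using a toy stand-in class (simplified and illustrative; not the literal code in _random.py):

```python
from contextlib import contextmanager
from typing import Optional

import torch

class _RNGTrackerSketch:
    """Toy stand-in for the RNG tracker; names and structure are illustrative."""

    def __init__(self) -> None:
        self.rng_states: dict[str, torch.Tensor] = {}

    @contextmanager
    def _distribute_region(self, generator: Optional[torch.Generator] = None):
        g_name = "user-passed-generator" if generator is not None else "parallel-rng"
        if generator is not None:
            # Register the caller's *current* state under a temporary key, so a
            # later manual_seed() on the user's generator is never shadowed by
            # a stale cached state.
            self.rng_states[g_name] = generator.get_state()
        try:
            # ... compute per-shard offsets and draw using rng_states[g_name] ...
            yield
        finally:
            if generator is not None:
                # Hand the advanced state back to the caller and drop the entry.
                generator.set_state(self.rng_states.pop(g_name))
```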
Thanks, Will!
I ended up implementing the version where the user-passed RNG IS mutated after DTensor uses it (option (b)), because I think this is the UX we want to aim for anyway.
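A conceptual illustration of the chosen semantics, option (b) (hypothetical sizes, seed, and setup; assumes a torchrun launch where init_device_mesh initializes the default process group):

```python
import os
import torch
from torch.distributed.device_mesh import init_device_mesh
from torch.distributed.tensor import Shard, distribute_tensor

mesh = init_device_mesh("cuda", (int(os.environ["WORLD_SIZE"]),))
dt = distribute_tensor(torch.empty(256, 256, device="cuda"), mesh, [Shard(0)])

g = torch.Generator(device="cuda").manual_seed(0)
state_before = g.get_state().clone()

# With this PR, the DTensor op consumes and advances the user's generator.
torch.nn.init.uniform_(dt, 0.0, 1.0, generator=g)

# Option (b): the generator object the user passed in was mutated, advancing
# in an SPMD-consistent way rather than being left untouched.
assert not torch.equal(g.get_state(), state_before)
```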
this API change lgtm
```python
# ensure that we do not cache the 'seed' of `rng` from the first time we see it in DTensor
# TODO: we have a semantics decision to make
# There is a discontinuity between how the default RNG and a user-supplied RNG behaves with DTensor:
# (a) if the user calls `torch.manual_seed` after already using the default RNG with DTensor,
#     they may be surprised that it has no effect on DTensor. They must instead call this private API
#     (`torch.distributed.tensor._random._rng_tracker._manual_seed`)
# (b) If we try to match the semantics of (a) with a user-supplied RNG, they may be very surprised to find that
#     their RNG object never advances its state after using it with DTensor.
# torch.distributed.tensor._random._rng_tracker._manual_seed(55)
# rng.manual_seed(55)
# torch.nn.init.uniform_(t1, 0.0, 1.0)
# torch.nn.init.uniform_(t2, 0.0, 1.0, rng)
# self.assertEqual(t1.full_tensor(), t2.full_tensor())
```
Clean up this comment; maybe move it to _random.py?
I think it's OK to leave it here for now. I still need to get some agreement on whether we're changing the default RNG behavior; then I'd make another PR to do that, and I can remove this TODO and enable this part of the test.
lgtm
```python
# This is a little hacky, but for any user-passed generator, we store its state under a unique key,
# not because we need to keep a copy of it but because its the easiest way to make it work with the
# existing set/get APIs. We also ensure we remove it from rng_states after each _distribute_region.
g_name = "user-passed-generator"
```
I wonder if it could just be str(generator) as the key?
I can try this. I think I'll land as-is and experiment more with this in the next PR.
I already tried using id(generator) as the key; this did not work, and I suspect we are making copies of the Python wrapper at some point in our bindings or dispatching layer, leading to the id changing. I do notice that str(generator) prints what looks like a memory address for the underlying C++ object, so it might indeed be more stable and fix my issue. Thanks for the suggestion!
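A quick illustration of the keying discussion (the printed address format is approximate):

```python
import torch

g = torch.Generator()
print(id(g))   # identity of the Python wrapper; can differ if the wrapper is re-created
print(str(g))  # e.g. "<torch._C.Generator object at 0x7f...>", which embeds an address
```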
@pytorchbot merge -i

Merge started. Your change will be merged while ignoring the following 3 checks: Check Labels / Check labels, Check mergeability of ghstack PR / ghstack-mergeability-check, pull / linux-jammy-py3_9-clang9-xla / test (xla, 1, 1, linux.12xlarge, unstable).
Pull Request resolved: pytorch#159933
Approved by: https://github.com/fduwjj, https://github.com/XilunWu, https://github.com/wanchaol