[DTensor] Support user-supplied Generator for random ops by wconstab · Pull Request #159933 · pytorch/pytorch · GitHub

Conversation

@wconstab
Contributor

@wconstab wconstab commented Aug 6, 2025

Stack from ghstack (oldest at bottom):

If the user provides a generator kwarg to a random op (e.g.
nn.init.uniform_(..., generator=my_generator)), we can still advance
that generator's state in an SPMD-global way, so that each local tensor
gets appropriate values and the generator advances to the same state as
if it had operated on the full tensor.

cc @H-Huang @awgu @wanchaol @fegin @fduwjj @wz337 @d4l3k @pragupta
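
For reference, a minimal usage sketch of the behavior this enables. The mesh setup and names below are illustrative, and it assumes each rank seeds the generator identically, since DTensor does not synchronize it:

```python
import torch
import torch.distributed as dist
from torch.distributed.device_mesh import init_device_mesh
from torch.distributed.tensor import Shard, distribute_tensor

# assumes the process group is already initialized (e.g. launched via torchrun)
device_type = "cuda" if torch.cuda.is_available() else "cpu"
mesh = init_device_mesh(device_type, (dist.get_world_size(),))

# the user must give the generator the same state on every rank; DTensor does not sync it
gen = torch.Generator(device=device_type)
gen.manual_seed(42)

dt = distribute_tensor(torch.empty(1024, 1024, device=device_type), mesh, [Shard(0)])
torch.nn.init.uniform_(dt, 0.0, 1.0, generator=gen)
# each local shard receives its slice of the "global" random tensor, and `gen`
# advances to the state it would have after sampling the full 1024x1024 tensor
```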

@pytorch-bot

pytorch-bot bot commented Aug 6, 2025

🔗 Helpful Links

🧪 See artifacts and rendered test results at hud.pytorch.org/pr/159933

Note: Links to docs will display an error until the docs builds have been completed.

✅ You can merge normally! (3 Unrelated Failures)

As of commit 7a01a13 with merge base 908c5cc:

UNSTABLE - The following jobs are marked as unstable, possibly due to flakiness on trunk:

This comment was automatically generated by Dr. CI and updates every 15 minutes.

wconstab added a commit that referenced this pull request Aug 6, 2025

ghstack-source-id: b93f34f
Pull Request resolved: #159933
@pytorch-bot pytorch-bot bot added the ciflow/inductor and oncall: distributed labels Aug 6, 2025

-    def _distribute_region(self, spec: DTensorSpec):
+    def _distribute_region(
+        self, spec: DTensorSpec, generator: Optional[torch.Generator]
Contributor Author

Missing a =None default on the new generator parameter.
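
i.e., presumably something like:

```python
def _distribute_region(
    self, spec: DTensorSpec, generator: Optional[torch.Generator] = None
):
    ...
```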

Contributor

@XilunWu XilunWu left a comment

Overall LGTM. Let me know if you need a quick unblock.

We need to explicitly tell users to initialize the passed-in generator with the same seed on every rank, or add this functionality to our manual_seed API; either way, users still need to be aware of this.

assert maybe_user_generator is None or isinstance(
    maybe_user_generator, torch.Generator
)
# maybe_user_generator = None
Contributor

remove comment

Comment on lines 205 to 206
if g_name not in self.rng_states:
    self.rng_states[g_name] = generator.get_state()
Contributor

This is a slight behavior divergence between using the default RNG and using a user-specified RNG. Either we require users to call our manual_seed() API with the RNG passed in, or we optimistically assume users know what they're doing and are responsible for initializing the RNG with the right seed value across ranks (which is what we're doing here).

Contributor Author

I think I like the behavior this way: I don't want to introduce a collective on every op where a user supplies an RNG.

I will add documentation to the dtensor docs stating that it is the user's responsibility to ensure the passed generator has the same state on every SPMD rank.

Contributor

The original feature request comes from @akolesnikoff and myself. We noticed the issue specifically because we were wondering about, and comparing, the behaviour of passing in identically vs. differently seeded RNGs across ranks. So at least from our perspective, yes, this is intentional.

Contributor

I see.

> passing in same vs different seeded RNGs across ranks

I assume you're trying to pass the same RNG for Data Parallel and different RNGs for Model Parallel (let me know if this is not the case).

# not because we need to keep a copy of it but because it's the easiest way to make it work with the
# existing set/get APIs
g_name = str(id(generator))
if g_name not in self.rng_states:
Contributor Author

I will also add 2 more test cases and fix one bug here:

If the user changed the seed of their generator after using it with DTensor, we'd cache the first seed and use that forever. To prevent this, I'm going to pop the temporary generator state back out of self.rng_states at the end of this context, so we add it fresh again every time.

This makes me want to do a refactor: I think we should not be storing states like this. We should probably just store a ref to our own generator and then use it.
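
A rough sketch of the "add it fresh, then pop it back out" idea (the helper and names here are illustrative, not the actual DTensor internals):

```python
from contextlib import contextmanager

import torch


@contextmanager
def _track_user_generator(rng_states: dict, generator: torch.Generator):
    # hypothetical helper: stash the user's generator state under a temporary key
    # for the duration of one _distribute_region, then write the advanced state
    # back into the user's generator and drop the cache entry so a later seed
    # change is picked up next time.
    key = "user-passed-generator"
    rng_states[key] = generator.get_state()
    try:
        yield key
    finally:
        generator.set_state(rng_states.pop(key))
```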

wconstab added a commit that referenced this pull request Aug 6, 2025

ghstack-source-id: 7b6672e
Pull Request resolved: #159933
wconstab added a commit that referenced this pull request Aug 6, 2025

ghstack-source-id: d5ab5a2
Pull Request resolved: #159933
@lucasb-eyer
Contributor

Thanks, Will!

wconstab added a commit that referenced this pull request Aug 6, 2025

ghstack-source-id: d2523f2
Pull Request resolved: #159933
@wconstab wconstab added the release notes: distributed (dtensor) label Aug 6, 2025
@wconstab
Contributor Author

wconstab commented Aug 6, 2025

I ended up implementing the version where the user-passed RNG IS mutated after it is used by DTensor.
(a) because I ran into trouble with the previous implementation that tried to cache the generator by id(generator): I found that on some ranks the id of the generator in the DTensor op kwargs changes, and I gave up trying to figure out why (guessing there is a copy happening somewhere?).

(b) because I think this is the UX we want to aim for anyway.
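
Roughly, the observable contract is the following sketch, where dt is a DTensor and gen is the user-supplied generator, seeded identically on every rank:

```python
before = gen.get_state().clone()
torch.nn.init.uniform_(dt, 0.0, 1.0, generator=gen)
after = gen.get_state()

# the user's generator really was advanced, and by the same amount on every
# rank, as if uniform_ had sampled the full (unsharded) tensor
assert not torch.equal(before, after)
```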

Contributor

@XilunWu XilunWu left a comment

this API change lgtm

Comment on lines +108 to +121
# ensure that we do not cache the 'seed' of `rng` from the first time we see it in DTensor
# TODO: we have a semantics decision to make
# There is a discontinuity between how the default RNG and a user-supplied RNG behaves with DTensor:
# (a) if the user calls `torch.manual_seed` after already using the default RNG with DTensor,
# they may be surprised that it has no effect on DTensor. They must instead call this private API
# (`torch.distributed.tensor._random._rng_tracker._manual_seed`)
# (b) If we try to match the semantics of (a) with a user-supplied RNG, they may be very surprised to find that
# their RNG object never advances its state after using it with DTensor.
# torch.distributed.tensor._random._rng_tracker._manual_seed(55)
# rng.manual_seed(55)
# torch.nn.init.uniform_(t1, 0.0, 1.0)
# torch.nn.init.uniform_(t2, 0.0, 1.0, rng)
# self.assertEqual(t1.full_tensor(), t2.full_tensor())

Contributor

clean up this comment, maybe move to _random.py?

Contributor Author

I think it's OK to leave it here for now. I still need to get some agreement on whether we're changing the default RNG behavior; then I'd make another PR to do that, and I can remove this TODO and enable this part of the test.

Collaborator

@wanchaol wanchaol left a comment

lgtm

# This is a little hacky, but for any user-passed generator, we store its state under a unique key,
# not because we need to keep a copy of it but because it's the easiest way to make it work with the
# existing set/get APIs. We also ensure we remove it from rng_states after each _distribute_region.
g_name = "user-passed-generator"
Collaborator

I wonder if it could just be str(generator) as the key?

Contributor Author

I can try this. I think I'll land as-is and experiment more with this in the next PR.

I already tried using id(generator) as the key; this did not work, and I suspect we are making copies of the Python wrapper at some point in our bindings or dispatching layer, leading to the id changing. I do notice that str(generator) prints what looks like a memory address for the underlying C++ object, so it might indeed be more stable and fix my issue. Thanks for the suggestion!

@wconstab
Contributor Author

wconstab commented Aug 7, 2025

@pytorchbot merge -i

@pytorch-bot pytorch-bot bot added the ciflow/trunk label Aug 7, 2025
@pytorchmergebot
Collaborator

Merge started

Your change will be merged while ignoring the following 3 checks: Check Labels / Check labels, Check mergeability of ghstack PR / ghstack-mergeability-check, pull / linux-jammy-py3_9-clang9-xla / test (xla, 1, 1, linux.12xlarge, unstable)

Learn more about merging in the wiki.

Questions? Feedback? Please reach out to the PyTorch DevX Team

Advanced Debugging: Check the merge workflow status here.

@github-actions github-actions bot deleted the gh/wconstab/439/head branch September 7, 2025 02:13
markc-614 pushed a commit to markc-614/pytorch that referenced this pull request Sep 17, 2025


Pull Request resolved: pytorch#159933
Approved by: https://github.com/fduwjj, https://github.com/XilunWu, https://github.com/wanchaol