[dtensor][random] allow user to manual_seed different seed on device mesh; only sync RNG state in WORLD when manual_seed has not been called #141223
Conversation
…mesh; only sync RNG state in WORLD when manual_seed has not been called [ghstack-poisoned]
🔗 Helpful Links: 🧪 See artifacts and rendered test results at hud.pytorch.org/pr/141223
Note: Links to docs will display an error until the docs builds have been completed.
✅ You can merge normally! (1 Unrelated Failure) As of commit 366bd1c with merge base 6a22cae. FLAKY - the following job failed but was likely due to flakiness present on trunk.
This comment was automatically generated by Dr. CI and updates every 15 minutes.
> When calling this function, :func:`manual_seed` must be called from all ranks of the
> default ``ProcessGroup`` even if some ranks may not be a part of the ``device_mesh``,
> with the same ``seed`` value.
> :func:`manual_seed` does not check the ``seed`` value correctness. Users must
For a future PR - I wonder if we want to put `manual_seed` into the tensor/__init__.py namespace?
Along with that change, I think we need a tutorial and a composability test that hopefully share the same code snippet: initialize 8 GPUs with PP + 2D SPMD, and show that a toy module with sufficient complexity (e.g. a replicated bias and a sharded param) gets different weights on the expected ranks and the same weights on the expected ranks (a sketch follows below).
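Something like the following might serve as a starting point for that composability test. This is a hypothetical sketch only, assuming 8 GPUs launched with `torchrun --nproc_per_node=8`; the toy module, rank-layout assumptions, and checks are illustrative and not part of this PR.

```python
# Hypothetical sketch of the composability test described above; NOT the test in
# this PR. Assumes 8 GPUs, launched with: torchrun --nproc_per_node=8 sketch.py
import os

import torch
import torch.distributed as dist
import torch.distributed.tensor._random as dtensor_random
from torch.distributed.device_mesh import init_device_mesh
from torch.distributed.tensor import Replicate, Shard, distribute_tensor


def main() -> None:
    torch.cuda.set_device(int(os.environ["LOCAL_RANK"]))
    world_mesh = init_device_mesh(
        "cuda", (2, 2, 2), mesh_dim_names=("pp", "dp", "tp")
    )
    pp_rank = world_mesh["pp"].get_local_rank()
    spmd_mesh = world_mesh["dp", "tp"]._flatten("spmd")

    # Different seed per pipeline stage, same seed within each SPMD group.
    dtensor_random.manual_seed(123 + pp_rank, spmd_mesh)

    # Toy "module": a tp-sharded weight and a replicated bias, both initialized
    # through DTensor's RNG tracker via in-place random ops.
    tp_mesh = world_mesh["tp"]
    weight = distribute_tensor(torch.empty(16, 16, device="cuda"), tp_mesh, [Shard(0)])
    bias = distribute_tensor(torch.empty(16, device="cuda"), tp_mesh, [Replicate()])
    weight.normal_()
    bias.normal_()

    # Expectation: the replicated bias is identical across the ranks of one
    # pipeline stage (same seed) and differs across stages (different seeds);
    # the sharded weight differs shard-by-shard within a stage.
    full_bias = bias.full_tensor()
    gathered = [torch.empty_like(full_bias) for _ in range(dist.get_world_size())]
    dist.all_gather(gathered, full_bias)
    if dist.get_rank() == 0:
        # With the default rank layout, global ranks 0 and 1 share pp stage 0,
        # while rank 4 is the first rank of pp stage 1.
        print("same within stage 0:", torch.equal(gathered[0], gathered[1]))
        print("differs across stages:", not torch.equal(gathered[0], gathered[4]))


if __name__ == "__main__":
    main()
```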
I agree with and like everything @wconstab said!
@XilunWu FYI, I have a PR on the torchtitan side which uses this PR. It crashes on something that may not be supported well inside OffsetBasedRNGTracker. Maybe this is a regression from no longer using TensorParallelRNGTracker for TorchTitan? PyTorch branch: this PR. Error:
… on device mesh; only sync RNG state in WORLD when manual_seed has not been called"
**Summary**
This PR proposes 3 changes to DTensor RNG management:
1. DTensor allows users to eagerly initialize the RNG tracker by calling `torch.distributed.tensor._random.manual_seed`.
2. DTensor `manual_seed` no longer checks the integrity of the `seed` argument. Users are responsible for setting the same seed on all ranks within an SPMD group, but if there are multiple separate SPMD groups (e.g. across pipeline stages), users should set a _different_ seed for each SPMD group. For cases like Pipeline Parallel, users can set a different initial seed for each pipeline stage by calling:
```
world_mesh = init_device_mesh(
device_type="cuda",
mesh_shape=(2, 2, 2),
mesh_dim_names=("pp", "dp", "tp"),
)
pp_mesh = world_mesh["pp"]
pp_rank = pp_mesh.get_local_rank()
spmd_mesh = world_mesh["dp", "tp"]._flatten("spmd")  # this flattening is only needed if you need to call collective over this mesh
torch.distributed.tensor._random.manual_seed(123+pp_rank, spmd_mesh)
```
In other words, if users call `torch.distributed.tensor._random.manual_seed`, they are responsible for passing in the right value and DTensor won't perform any checks on it. If the current rank is not a part of the mesh, the current device RNG state is used to initialize the tracker.
3. `OffsetBasedRNGTracker` still performs RNG state synchronization by broadcasting the RNG state on rank 0 to `WORLD`. However, if `torch.distributed.tensor._random.manual_seed` has been called, no broadcast happens (see the sketch below).
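For illustration, here is a minimal sketch (not code from this PR) of the two initialization paths described above, assuming a single-host multi-GPU job launched via `torchrun`:

```python
# Minimal sketch (not PR code) of the two initialization paths, assuming a
# single-host multi-GPU job launched via torchrun.
import os

import torch
import torch.distributed.tensor._random as dtensor_random
from torch.distributed.device_mesh import init_device_mesh
from torch.distributed.tensor import Shard, distribute_tensor

torch.cuda.set_device(int(os.environ["LOCAL_RANK"]))
mesh = init_device_mesh("cuda", (int(os.environ["WORLD_SIZE"]),))

# Path A (default, unchanged): do not call manual_seed. The first random DTensor
# op lazily creates OffsetBasedRNGTracker, which broadcasts rank 0's RNG state to
# WORLD so every rank starts from the same state.

# Path B (eager, this PR): call manual_seed yourself. No broadcast happens; the
# tracker on this rank is initialized from the seed you pass in.
dtensor_random.manual_seed(1234, mesh)

# Either way, subsequent random DTensor ops go through the tracker:
w = distribute_tensor(torch.empty(8, 8, device="cuda"), mesh, [Shard(0)])
w.uniform_()
```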
**Motivation**
tl;dr
1. Lazily initializing the DTensor RNG tracker causes a hang in non-SPMD code such as Pipeline Parallel.
2. Users may want to set different seeds on the ranks of one device mesh.
3. We want to keep the old behavior for users who prefer not to curate the RNG state and want DTensor to take care of it.
See details in #140301.
**Test**
`pytest test/distributed/_tensor/test_random_ops.py`
`pytest test/distributed/tensor/parallel/test_tp_random_state.py`
cc wanchaol tianyu-l wz337 d4l3k H-Huang awgu kwen2501 fegin fduwjj wconstab c-p-i-occ c-p-i-o
[ghstack-poisoned]
Successfully ran the test in pytorch/torchtitan#689 after rebasing the PR on top.
Just curious: is this part necessary? or can we just pass
This sounds wrong to me - if the user passes a mesh that is supposed to be the 'spmd mesh', and the current rank is not a part of that mesh, shouldn't we raise an error?
> `# the current rank is in mesh`
> `if device_mesh.get_coordinate() is not None:`
Should we assert this instead of silently no-op-ing it?
If we require users to pass in a device mesh that is defined as the 'spmd world', then it is wholly reasonable to error out if the mesh does not contain our rank.
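For concreteness, a minimal sketch of the kind of check being suggested here; the actual error added in the PR may be placed and worded differently:

```python
# Minimal sketch of the suggested check, as it might sit near the top of
# DTensor's manual_seed; the actual error added in the PR may differ in wording.
import torch.distributed as dist
from torch.distributed.device_mesh import DeviceMesh


def _require_rank_in_mesh(device_mesh: DeviceMesh) -> None:
    # get_coordinate() returns None when the calling rank is not part of the
    # mesh, which is exactly the silent no-op case discussed above.
    if device_mesh.get_coordinate() is None:
        raise RuntimeError(
            f"manual_seed requires the calling rank ({dist.get_rank()}) "
            "to be part of the device mesh passed in."
        )
```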
approving as long as you add the runtime error in manual_seed for missing ranks.
> seed (int): The desired seed.
> device_mesh (:class:`DeviceMesh`): The device mesh to set the seed. It is
> required that the ``device_mesh`` include the calling rank. This is
nit: "include the calling rank" is too weak a statement.
It is required that 'device_mesh' contains the superset of all meshes that will be used with this DTensor. It would be good to define 'spmd world' and then say that this mesh must be the spmd_world.
… avoid overflow (#141532)
**Summary** DTensor RNG code raises an error if the seed passed in is beyond the `torch.int64` range (e.g. `torch.tensor([2**64-1])` raises an error). The solution is to specify `dtype=torch.uint64` in the `torch.tensor()` call.
Pull Request resolved: #141532
Approved by: https://github.com/wconstab
ghstack dependencies: #141731, #141220, #141223
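For illustration, a small sketch of the overflow described in that commit, assuming a PyTorch build with `torch.uint64` tensor support:

```python
# Illustration of the overflow described above; assumes a PyTorch build with
# torch.uint64 tensor support.
import torch

seed = 2**64 - 1
# torch.tensor([seed]) infers int64 and raises an overflow error for this value;
# specifying the dtype explicitly is the fix described in this commit.
t = torch.tensor([seed], dtype=torch.uint64)
print(t)
```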
…mesh; only sync RNG state in WORLD when manual_seed has not been called (pytorch#141223)
Pull Request resolved: pytorch#141223
Approved by: https://github.com/wconstab
ghstack dependencies: pytorch#141731, pytorch#141220


Stack from ghstack (oldest at bottom):
**Summary**
This PR proposes 4 changes to DTensor RNG management:
1. DTensor allows users to eagerly initialize the RNG tracker by calling `torch.distributed.tensor._random.manual_seed`.
2. DTensor `manual_seed` no longer checks the integrity of the `seed` argument. Users are responsible for setting the same seed on all ranks within an SPMD group, but if there are multiple separate SPMD groups (e.g. across pipeline stages), users should set a _different_ seed for each SPMD group. For cases like Pipeline Parallel, users can set a different initial seed for each pipeline stage by calling `torch.distributed.tensor._random.manual_seed(123 + pp_rank, spmd_mesh)`, where `spmd_mesh` is the flattened `("dp", "tp")` sub-mesh of the world mesh (see the code snippet earlier in this thread). In other words, if users call `torch.distributed.tensor._random.manual_seed`, they are responsible for passing in the right value and DTensor won't perform any checks on it. If the current rank is not a part of the mesh, the current device RNG state is used to initialize the tracker.
3. `OffsetBasedRNGTracker` still performs RNG state synchronization by broadcasting the RNG state on rank 0 to `WORLD`. However, if `torch.distributed.tensor._random.manual_seed` has been called, no broadcast happens.
4. Enforce that the `manual_seed` call only accepts a "full mesh", i.e. the DTensor RNG state on every rank must be set through the call. This makes sure that no rank has its RNG state left uninitialized and that the SPMD ranks have their RNG state in sync.

**Motivation**
tl;dr
1. Lazily initializing the DTensor RNG tracker causes a hang in non-SPMD code such as Pipeline Parallel.
2. Users may want to set different seeds on the ranks of one device mesh.
3. We want to keep the old behavior for users who prefer not to curate the RNG state and want DTensor to take care of it.

See details in #140301.

**Test**
`pytest test/distributed/_tensor/test_random_ops.py`
`pytest test/distributed/tensor/parallel/test_tp_random_state.py`

cc @wanchaol @tianyu-l @wz337 @d4l3k @H-Huang @awgu @kwen2501 @fegin @fduwjj @wconstab @c-p-i-o
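As a closing illustration of item 4 from the caller's side, a hypothetical sketch assuming 8 GPUs launched via `torchrun`; the error behavior paraphrases the review discussion above (a runtime error when the mesh does not contain the calling rank):

```python
# Hypothetical illustration of item 4 from the caller's side; assumes 8 GPUs
# launched via torchrun. Error behavior paraphrases the review discussion above.
import torch.distributed.tensor._random as dtensor_random
from torch.distributed.device_mesh import init_device_mesh

world_mesh = init_device_mesh("cuda", (2, 2, 2), mesh_dim_names=("pp", "dp", "tp"))
pp_rank = world_mesh["pp"].get_local_rank()
spmd_mesh = world_mesh["dp", "tp"]._flatten("spmd")

# OK: every rank makes the call, and the flattened ("dp", "tp") mesh on each rank
# contains that rank, so no rank's DTensor RNG state is left uninitialized.
dtensor_random.manual_seed(123 + pp_rank, spmd_mesh)

# Not OK under this PR: skipping the call on some ranks, or passing a mesh that
# does not contain the current rank (this now raises a runtime error).
```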