[dtensor][random] allow user to manual_seed different seed on device mesh; only sync RNG state in WORLD when manual_seed has not been called by XilunWu · Pull Request #141223 · pytorch/pytorch · GitHub

Conversation

@XilunWu
Contributor

@XilunWu XilunWu commented Nov 21, 2024

Stack from ghstack (oldest at bottom):

Summary
This PR proposes 4 changes to DTensor RNG management:

  1. DTensor allows users to eagerly initialize the RNG tracker by calling torch.distributed.tensor._random.manual_seed.
  2. DTensor manual_seed no longer checks the integrity of the seed argument. Users are responsible for setting the same seed on all ranks within an SPMD group, but if there are multiple separate SPMD groups (e.g. across pipeline stages), users should set a different seed for each SPMD group. For cases like Pipeline Parallel, users can set a different initial seed for each pipeline stage by calling:
world_mesh = init_device_mesh(
    device_type="cuda",
    mesh_shape=(2, 2, 2),
    mesh_dim_names=("pp", "dp", "tp"),
)
pp_mesh = world_mesh["pp"]
pp_rank = pp_mesh.get_local_rank()
spmd_mesh = world_mesh["dp", "tp"]._flatten("spmd")  # this flattening is only needed if you need to call collective over this mesh
torch.distributed.tensor._random.manual_seed(123+pp_rank, spmd_mesh)

In other words, if users want to call torch.distributed.tensor._random.manual_seed, they are responsible for passing in the right value and DTensor won't perform any checks on it. If the current rank is not a part of the mesh, it will use the current device RNG state to initialize.

  3. OffsetBasedRNGTracker still performs RNG state synchronization by broadcasting the RNG state on rank 0 to WORLD. However, calling torch.distributed.tensor._random.manual_seed is an exception: in that case, no broadcast happens.

  4. Enforce that the manual_seed call only accepts a "full mesh", i.e. the DTensor RNG state on every rank must be set through the call. This makes sure that no rank has its RNG state left uninitialized and that the SPMD ranks have their RNG states synchronized. A minimal sketch of both paths follows below.
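
For illustration, a minimal sketch of the seeded and unseeded paths described above (the mesh shape, seed, and tensor size are arbitrary; this assumes an already-launched multi-rank CUDA job):

```
import torch
import torch.distributed.tensor._random as dtensor_random
from torch.distributed.device_mesh import init_device_mesh
from torch.distributed.tensor import Shard, distribute_tensor

mesh = init_device_mesh("cuda", (2,))  # illustrative 2-rank mesh

# Seeded path (changes 1, 2, 4): eagerly initialize the tracker yourself;
# no RNG state broadcast happens, and the mesh must cover every rank.
dtensor_random.manual_seed(1234, mesh)

# Unseeded path (change 3): skip manual_seed. On the first random DTensor op,
# the tracker is created lazily and rank 0's RNG state is broadcast to WORLD,
# preserving the old behavior.

# Either way, random DTensor ops then draw from the tracked state.
x = distribute_tensor(torch.empty(16, device="cuda"), mesh, [Shard(0)])
x.uniform_(0, 1)
```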

Motivation
tl;dr

  1. Lazily initializing the DTensor RNG tracker causes hangs in non-SPMD code such as Pipeline Parallel.
  2. Users may want to set different seeds on different ranks within one device mesh.
  3. We want to keep the old behavior for users who prefer not to curate the RNG state themselves and want DTensor to take care of it.

See details in #140301.

Test
pytest test/distributed/_tensor/test_random_ops.py
pytest test/distributed/tensor/parallel/test_tp_random_state.py

cc @wanchaol @tianyu-l @wz337 @d4l3k @H-Huang @awgu @kwen2501 @fegin @fduwjj @wconstab @c-p-i-occ @c-p-i-o

…mesh; only sync RNG state in WORLD when manual_seed has not been called

[ghstack-poisoned]
@pytorch-bot

pytorch-bot bot commented Nov 21, 2024

🔗 Helpful Links

🧪 See artifacts and rendered test results at hud.pytorch.org/pr/141223

Note: Links to docs will display an error until the docs builds have been completed.

✅ You can merge normally! (1 Unrelated Failure)

As of commit 366bd1c with merge base 6a22cae:

FLAKY - The following job failed but was likely due to flakiness present on trunk:

This comment was automatically generated by Dr. CI and updates every 15 minutes.

XilunWu added a commit that referenced this pull request Nov 21, 2024
…mesh; only sync RNG state in WORLD when manual_seed has not been called

ghstack-source-id: e90f3ff
Pull Request resolved: #141223
@pytorch-bot pytorch-bot bot added the oncall: distributed Add this issue/PR to distributed oncall triage queue label Nov 21, 2024
@XilunWu XilunWu added module: dtensor distributed tensor tag release notes: distributed (dtensor) release notes category and removed oncall: distributed Add this issue/PR to distributed oncall triage queue labels Nov 21, 2024
@pytorch-bot pytorch-bot bot added the oncall: distributed Add this issue/PR to distributed oncall triage queue label Nov 21, 2024
@wconstab
Contributor

Looks pretty good to me. I want to make a few changes to the PR description so other readers can understand it better:

  1. Update this code snippet to show how pp_device_mesh is defined. (It should be a submesh of the world that includes all the dimensions _other than_ the PP dim.) I would think the better variable name is `non_pp_mesh` or `spmd_mesh`.
  2. Change this text to: "Users are responsible for setting the same seed on all ranks within an SPMD group, but if there are multiple separate SPMD groups (e.g. across pipeline stages), users should set a _different_ seed for each SPMD group."

When calling this function, :func:`manual_seed` must be called from all ranks of the
default ``ProcessGroup`` even if some ranks may not be a part of the ``device_mesh``,
with the same ``seed`` value.
:func:`manual_seed` does not check the ``seed`` value correctness. Users must
@wconstab (Contributor) commented Nov 21, 2024:

For a future PR: I wonder if we want to put 'manual_seed' into the tensor/__init__.py namespace?

Along with that change, I think we need a tutorial and a composability test that hopefully share the same code snippet: initialize 8 GPUs with PP + 2D SPMD and show a toy module with sufficient complexity (e.g. a replicated bias and a sharded param) that gets different weights on the expected ranks and the same weights on the expected ranks. A rough sketch of such a test is included below.
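
For illustration only, a rough sketch of such a composability test, assuming 8 CUDA ranks launched with torchrun; the toy module, seed values, and fingerprint check below are hypothetical, not the actual test or torchtitan code:

```
import torch
import torch.distributed as dist
import torch.nn as nn
import torch.distributed.tensor._random as dtensor_random
from torch.distributed.device_mesh import init_device_mesh
from torch.distributed.tensor.parallel import RowwiseParallel, parallelize_module

# 8 ranks: pp=2, dp=2, tp=2 (e.g. torchrun --nproc-per-node 8 this_file.py)
world_mesh = init_device_mesh("cuda", (2, 2, 2), mesh_dim_names=("pp", "dp", "tp"))
torch.cuda.set_device(dist.get_rank() % torch.cuda.device_count())
pp_rank = world_mesh["pp"].get_local_rank()
spmd_mesh = world_mesh["dp", "tp"]._flatten("spmd")

# Different seed per pipeline stage, same seed within each SPMD group.
dtensor_random.manual_seed(123 + pp_rank, spmd_mesh)

# Toy module: RowwiseParallel shards the weight and keeps the bias replicated.
model = parallelize_module(
    nn.Linear(16, 16, bias=True).cuda(), world_mesh["tp"], RowwiseParallel()
)
with torch.no_grad():
    model.weight.uniform_(-0.1, 0.1)  # random DTensor op, goes through the tracker
    model.bias.uniform_(-0.1, 0.1)

# Gather a cheap fingerprint of the materialized weights from every rank:
# ranks in the same pp stage should agree, different pp stages should differ.
fingerprint = (
    pp_rank,
    model.weight.full_tensor().sum().item(),
    model.bias.full_tensor().sum().item(),
)
all_fps = [None] * dist.get_world_size()
dist.all_gather_object(all_fps, fingerprint)
if dist.get_rank() == 0:
    for stage, w_sum, b_sum in all_fps:
        print(f"pp stage {stage}: weight sum {w_sum:.6f}, bias sum {b_sum:.6f}")
```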

@tianyu-l (Contributor) left a comment:

I agree with and like everything @wconstab said!

@wconstab
Contributor

wconstab commented Nov 23, 2024

@XilunWu FYI, I have a PR on the torchtitan side which uses this PR. It crashes on something that may not be supported well inside OffsetBasedRNGTracker. Maybe this is a regression from no longer using TensorParallelRNGTracker for TorchTitan?

Pytorch branch: this PR
Torchtitan branch: pytorch/torchtitan#689
Run command: TORCH_LOGS=+pp CONFIG_FILE=./train_configs/debug_model.toml NGPU=8 LOG_RANK=0,1,2,3,4,5,6,7 ./run_llama_train.sh --job.dump_folder outputs/3d_compile --model.flavor debugmodel --experimental.pipeline_parallel_degree 2 --experimental.pipeline_parallel_split_points layers.4 --training.data_parallel_shard_degree 2 --training.tensor_parallel_degree 2

Error:

[rank5]:[rank5]:     layer.init_weights()
[rank5]:[rank5]:   File "/data/users/whc/torchtitan/torchtitan/models/llama/model.py", line 331, in init_weights                                             
[rank5]:[rank5]:     self.attention.init_weights(self.weight_init_std)                                                                                       
[rank5]:[rank5]:   File "/data/users/whc/torchtitan/torchtitan/models/llama/model.py", line 170, in init_weights                                             
[rank5]:[rank5]:     nn.init.trunc_normal_(linear.weight, mean=0.0, std=0.02)                                                                                
[rank5]:[rank5]:   File "/data/users/whc/pytorch/torch/nn/init.py", line 224, in trunc_normal_                                                               
[rank5]:[rank5]:     return _no_grad_trunc_normal_(tensor, mean, std, a, b, generator=generator)                                                             
[rank5]:[rank5]:   File "/data/users/whc/pytorch/torch/nn/init.py", line 47, in _no_grad_trunc_normal_                                                       
[rank5]:[rank5]:     tensor.uniform_(2 * l - 1, 2 * u - 1, generator=generator)                                                                              
[rank5]:[rank5]:   File "/data/users/whc/pytorch/torch/_compile.py", line 32, in inner                                                                       
[rank5]:[rank5]:     return disable_fn(*args, **kwargs)                                                                                                      
[rank5]:[rank5]:   File "/data/users/whc/pytorch/torch/_dynamo/eval_frame.py", line 721, in _fn                                                              
[rank5]:[rank5]:     return fn(*args, **kwargs)                                                                                                              
[rank5]:[rank5]:   File "/data/users/whc/pytorch/torch/distributed/tensor/_api.py", line 343, in __torch_dispatch__                                          
[rank5]:[rank5]:     return DTensor._op_dispatcher.dispatch(                                                                                                 
[rank5]:[rank5]:   File "/data/users/whc/pytorch/torch/distributed/tensor/_dispatch.py", line 211, in dispatch                                               
[rank5]:[rank5]:     with rng_context:                                                                                                                       
[rank5]:[rank5]:   File "/home/whc/.conda/envs/pytorch-3.10/lib/python3.10/contextlib.py", line 135, in __enter__                                            
[rank5]:[rank5]:     return next(self.gen)                                                                                                                   
[rank5]:[rank5]:   File "/data/users/whc/pytorch/torch/distributed/tensor/_random.py", line 176, in _distribute_region                                       
[rank5]:[rank5]:     self._set_pre_op_offset(spec)                                                                                                           
[rank5]:[rank5]:   File "/data/users/whc/pytorch/torch/distributed/tensor/_random.py", line 255, in _set_pre_op_offset                                       
[rank5]:[rank5]:     dim_map = spec.dim_map                                                                                                                  
[rank5]:[rank5]:   File "/data/users/whc/pytorch/torch/distributed/tensor/_dtensor_spec.py", line 164, in dim_map                                            
[rank5]:[rank5]:     raise ValueError(                                                                                                                       
[rank5]:[rank5]: ValueError: Tensor dim 0 is already sharded on mesh dim 0, DTensor operator implementation does not support things like hybrid sharding strategies yet (i.e. [Shard(0), Shard(0)])

XilunWu added a commit that referenced this pull request Nov 26, 2024
…mesh; only sync RNG state in WORLD when manual_seed has not been called

ghstack-source-id: 2352d74
Pull Request resolved: #141223
@XilunWu
Contributor Author

XilunWu commented Nov 26, 2024

Successfully ran the test in pytorch/torchtitan#689 after rebasing the PR on top.

@XilunWu XilunWu requested review from tianyu-l and wconstab November 26, 2024 01:13
@wconstab
Contributor

spmd_mesh = world_mesh["dp", "tp"]._flatten("spmd")

Just curious: is this part necessary, or can we just pass world_mesh["dp", "tp"] into manual_seed?

If the current rank is not a part of the mesh, it will use the current device RNG state to initialize.

This sounds wrong to me: if the user passes a mesh that is supposed to be the 'spmd mesh' and the current rank is not a part of that mesh, shouldn't we raise an error?

)

# the current rank is in mesh
if device_mesh.get_coordinate() is not None:

Should we assert this instead of silently no-op-ing it?

If we require users to pass in a device mesh that is defined as the 'spmd world', then it is wholly reasonable to error out if the mesh does not contain our rank. Something like the sketch below.
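
For illustration, the kind of check being suggested might look like this (a sketch against a hypothetical helper, not the actual diff):

```
import torch.distributed as dist

def _check_rank_in_mesh(device_mesh) -> None:
    # Fail loudly instead of silently no-op-ing when the calling rank is not
    # part of the mesh passed to manual_seed.
    if device_mesh.get_coordinate() is None:
        raise RuntimeError(
            "manual_seed requires the calling rank to be part of `device_mesh`, "
            f"but rank {dist.get_rank()} is not in the mesh."
        )
```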

@wconstab (Contributor) left a comment:

Approving, as long as you add the runtime error in manual_seed for missing ranks.

seed (int): The desired seed.
device_mesh (:class:`DeviceMesh`): The device mesh to set the seed.
device_mesh (:class:`DeviceMesh`): The device mesh to set the seed. It is
required that the ``device_mesh`` include the calling rank. This is

nit: "include the calling rank" is too weak a statement.

It is required that 'device_mesh' contains the superset of all meshes that will be used with this DTensor. It would be good to define 'spmd world' and then say that this mesh must be the spmd_world.

pytorchmergebot pushed a commit that referenced this pull request Nov 29, 2024
… avoid overflow (#141532)

**Summary**
The DTensor RNG code raises an error if the seed passed in is beyond the `torch.int64` range (e.g. `torch.tensor([2**64-1])` raises an error). The solution is to specify `dtype=torch.uint64` in the `torch.tensor()` call (see the short example below).

Pull Request resolved: #141532
Approved by: https://github.com/wconstab
ghstack dependencies: #141731, #141220, #141223
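
For illustration only (assuming the behavior described in the commit summary above), the difference the explicit dtype makes:

```
import torch

seed = 2**64 - 1
# torch.tensor([seed])  # overflows: the default inferred dtype is int64
seed_tensor = torch.tensor([seed], dtype=torch.uint64)  # keeps the full 64-bit value
```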
GeorgeWigley pushed a commit to graphcore/pytorch-fork that referenced this pull request Nov 29, 2024
…mesh; only sync RNG state in WORLD when manual_seed has not been called (pytorch#141223)

GeorgeWigley pushed a commit to graphcore/pytorch-fork that referenced this pull request Nov 29, 2024
… avoid overflow (pytorch#141532)

pobin6 pushed a commit to pobin6/pytorch that referenced this pull request Dec 5, 2024
…mesh; only sync RNG state in WORLD when manual_seed has not been called (pytorch#141223)

pobin6 pushed a commit to pobin6/pytorch that referenced this pull request Dec 5, 2024
… avoid overflow (pytorch#141532)

@github-actions github-actions bot deleted the gh/XilunWu/104/head branch December 30, 2024 02:08