[dtensor][random][tp] remove the adhoc DTensor RNG tracker TensorParallelRNGTracker since it does not match FSDP2+TP by XilunWu · Pull Request #141220 · pytorch/pytorch

Conversation

@XilunWu (Contributor) commented Nov 21, 2024

Stack from ghstack (oldest at bottom):

Summary
The ad-hoc DTensor RNG tracker was used to mimic Megatron's DDP+TP RNG behavior, but it turns out to be incompatible with PyTorch Distributed FSDP2+TP, so we are deprecating it and replacing it with `OffsetBasedRNGTracker`, which follows SPMD semantics (replicas get the same random sampling result, shards get different results).

Motivation
`TensorParallelRNGTracker` was designed for DDP+TP, where random operators produce the same result along the data-parallel mesh dimension and different results along the tensor-parallel dimension. However, this does not apply to the new composable FSDP2+TP combination, where model weights are sharded along the data-parallel mesh dimension as well. We therefore remove this outdated RNG tracker type for now. If you need an exact match between PyTorch Distributed and Megatron in random number generation results, feel free to file an issue.

Impact
`TensorParallelRNGTracker` was only used when Tensor Parallel was applied (i.e. when calling `parallelize_module`).

For non-FSDP users, the invariant "replicas get the same random numbers and shards get different ones" remains unchanged. Unlike `TensorParallelRNGTracker`, which set different seeds (`base_seed + 2718 + TP_rank`) within the TP group, DTensor now sets the same seed on all ranks (the default is 1234; users can call `torch.distributed.tensor._random.manual_seed` to change it) and chooses the right RNG offset based on the DTensor placements to enforce that invariant.

For FSDP2 users, this should be an improvement: a DTensor sharded within the DP group now gets different random samples per shard, which `TensorParallelRNGTracker` failed to do. We are not sure how much this change will improve the eventual training loss convergence, though.
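
For illustration only (not code from this PR), here is a minimal sketch of the invariant described above. It assumes a 2-GPU job launched with torchrun and uses the public `torch.distributed.tensor` random factory functions; the mesh shape and tensor sizes are illustrative.

```
import torch
import torch.distributed as dist
from torch.distributed.device_mesh import init_device_mesh
from torch.distributed.tensor import Replicate, Shard, rand

# 1-D tensor-parallel mesh over 2 GPUs.
mesh = init_device_mesh("cuda", (2,), mesh_dim_names=("tp",))

# Replicated placement: every rank draws the same random values.
replicated = rand(4, 4, device_mesh=mesh, placements=[Replicate()])

# Sharded placement: each rank's local shard draws different values.
sharded = rand(4, 4, device_mesh=mesh, placements=[Shard(0)])

print(f"rank {dist.get_rank()}: "
      f"replicated[0, 0]={replicated.to_local()[0, 0]}, "
      f"sharded local[0, 0]={sharded.to_local()[0, 0]}")
```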

Test
1-d model weight meta init:
`pytest test/distributed/_tensor/test_random_ops.py -s -k test_tp_model_meta_init`

2-d model weight meta init:
`pytest test/distributed/_tensor/test_random_ops.py -s -k test_fsdp_tp_model_meta_init`

TP model weight init test:
`pytest test/distributed/tensor/parallel/test_tp_random_state.py`

FSDP+TP model weight init test:
`pytest test/distributed/_composable/fsdp/test_fully_shard_init.py`

cc @H-Huang @awgu @kwen2501 @wanchaol @fegin @fduwjj @wz337 @wconstab @d4l3k @c-p-i-o @tianyu-l

@pytorch-bot pytorch-bot bot added the oncall: distributed Add this issue/PR to distributed oncall triage queue label Nov 21, 2024
XilunWu added a commit that referenced this pull request Nov 21, 2024
…llelRNGTracker since it does not match FSDP2+TP

ghstack-source-id: a78a69c
Pull Request resolved: #141220
pytorch-bot bot commented Nov 21, 2024

🔗 Helpful Links

🧪 See artifacts and rendered test results at hud.pytorch.org/pr/141220

Note: Links to docs will display an error until the docs builds have been completed.

✅ No Failures

As of commit fb5892a with merge base 6a22cae:
💚 Looks good so far! There are no failures yet. 💚

This comment was automatically generated by Dr. CI and updates every 15 minutes.

@wconstab (Contributor)

I think the change looks good, but we should offer more thorough documentation of the reasons and consequences of the change in the PR description.

  1. Why was TensorParallelRNGTracker wrong?
  2. In what situations will existing users be impacted? In what situations will existing users see no change?
  3. Did we do any test to ensure OffsetBasedRNGTracker actually works as expected for initializing a 1D or 2D model?

@XilunWu XilunWu requested a review from tianyu-l November 26, 2024 01:13
@wconstab (Contributor)

Test
TP model weight init test: pytest test/distributed/tensor/parallel/test_tp_random_state.py
FSDP+TP model weight init test: pytest test/distributed/_composable/fsdp/test_fully_shard_init.py

Can you say more about the test plan? I expect we need new tests or updates to existing tests. If existing tests were passing without this PR landing, then that proves they do not cover the case we need coverage for. In particular, I would expect the test to meta-init a model that has tensor-parallel and fully_shard applied to it, move it to empty, then initialize it. We should check that each shard gets unique values (and this should fail if we use the old TensorParallelRNGTracker). A rough sketch follows.
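
For reference, a rough sketch of the kind of check described above. This is hypothetical test code, not the tests added in this PR; it assumes 4 GPUs, a 2x2 (dp, tp) mesh, and the composable `fully_shard` / `parallelize_module` APIs, with a single `nn.Linear` standing in for the model.

```
import torch
import torch.distributed as dist
import torch.nn as nn
from torch.distributed._composable.fsdp import fully_shard
from torch.distributed.device_mesh import init_device_mesh
from torch.distributed.tensor.parallel import ColwiseParallel, parallelize_module

mesh = init_device_mesh("cuda", (2, 2), mesh_dim_names=("dp", "tp"))

# Meta-init the model, apply TP then FSDP2, then materialize and initialize.
with torch.device("meta"):
    model = nn.Linear(8, 8, bias=False)
parallelize_module(model, mesh["tp"], ColwiseParallel())
fully_shard(model, mesh=mesh["dp"])
model.to_empty(device="cuda")
torch.nn.init.normal_(model.weight)  # DTensor random op, goes through the RNG tracker

# Gather every rank's local shard and check that all shards are distinct.
# With the old TensorParallelRNGTracker, shards within a DP group would match.
local = model.weight.to_local().contiguous()
gathered = [torch.empty_like(local) for _ in range(dist.get_world_size())]
dist.all_gather(gathered, local)
if dist.get_rank() == 0:
    for i in range(len(gathered)):
        for j in range(i + 1, len(gathered)):
            assert not torch.equal(gathered[i], gathered[j]), f"shards {i} and {j} match"
```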

pytorchmergebot pushed a commit that referenced this pull request Nov 29, 2024
…mesh; only sync RNG state in WORLD when manual_seed has not been called (#141223)

**Summary**
This PR proposes 4 changes to DTensor RNG management:
1. DTensor allows users to eagerly initialize the RNG tracker by calling `torch.distributed.tensor._random.manual_seed`.
2. DTensor `manual_seed` no longer checks the integrity of the `seed` argument. Users are responsible for setting the same seed on all ranks within an SPMD group, but if there are multiple separate SPMD groups (e.g. across pipeline stages), users should set a _different_ seed for each SPMD group. For cases like Pipeline Parallel, users can set a different initial seed for each pipeline stage by calling
```
world_mesh = init_device_mesh(
    device_type="cuda",
    mesh_shape=(2, 2, 2),
    mesh_dim_names=("pp", "dp", "tp"),
)
pp_mesh = world_mesh["pp"]
pp_rank = pp_mesh.get_local_rank()
spmd_mesh = world_mesh["dp", "tp"]._flatten("spmd")  # this flattening is only needed if you need to call collective over this mesh
torch.distributed.tensor._random.manual_seed(123+pp_rank, spmd_mesh)
```

In other words, if users want to call `torch.distributed.tensor._random.manual_seed`, they are responsible for passing in the right value and DTensor will not perform any checks on it. If the current rank is not part of the mesh, the current device RNG state is used for initialization.

3. `OffsetBasedRNGTracker` still performs RNG state synchronization by broadcasting the RNG state from rank 0 to `WORLD`. However, calling `torch.distributed.tensor._random.manual_seed` is an exception: in that case, no broadcast happens.

4. Enforce that the `manual_seed` call only accepts a "full mesh", i.e. the DTensor RNG state on every rank must be set through the call. This ensures that no rank is left with an uninitialized RNG state and that the SPMD ranks keep their RNG states in sync.

**Motivation**
tl;dr

1. Lazily initializing the DTensor RNG tracker causes a hang in non-SPMD code such as Pipeline Parallel.
2. Users may want to set different seeds on different ranks within one device mesh.
3. We keep the old behavior for users who prefer not to curate the RNG state and want DTensor to take care of it.

see detail in #140301

**Test**
`pytest test/distributed/_tensor/test_random_ops.py`
`pytest test/distributed/tensor/parallel/test_tp_random_state.py`

Pull Request resolved: #141223
Approved by: https://github.com/wconstab
ghstack dependencies: #141731, #141220
pytorchmergebot pushed a commit that referenced this pull request Nov 29, 2024
… avoid overflow (#141532)

**Summary**
The DTensor RNG code raises an error if the seed passed in is beyond the `torch.int64` range (e.g. `torch.tensor([2**64-1])` raises an error). The solution is to specify `dtype=torch.uint64` in the `torch.tensor()` call.
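
A minimal sketch of the overflow and the fix (assuming a PyTorch build with `torch.uint64` support):

```
import torch

seed = 2**64 - 1  # the largest value a uint64 seed can hold

# Without an explicit dtype, tensor construction infers int64 and overflows.
try:
    torch.tensor([seed])
except (RuntimeError, OverflowError) as e:
    print(f"overflow as expected: {e}")

# Specifying dtype=torch.uint64 stores the seed without overflow.
print(torch.tensor([seed], dtype=torch.uint64))
```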

Pull Request resolved: #141532
Approved by: https://github.com/wconstab
ghstack dependencies: #141731, #141220, #141223
GeorgeWigley pushed a commit to graphcore/pytorch-fork that referenced this pull request Nov 29, 2024
…llelRNGTracker since it does not match FSDP2+TP (pytorch#141220)

GeorgeWigley pushed a commit to graphcore/pytorch-fork that referenced this pull request Nov 29, 2024
…mesh; only sync RNG state in WORLD when manual_seed has not been called (pytorch#141223)

GeorgeWigley pushed a commit to graphcore/pytorch-fork that referenced this pull request Nov 29, 2024
… avoid overflow (pytorch#141532)

pobin6 pushed a commit to pobin6/pytorch that referenced this pull request Dec 5, 2024
…llelRNGTracker since it does not match FSDP2+TP (pytorch#141220)

pobin6 pushed a commit to pobin6/pytorch that referenced this pull request Dec 5, 2024
…mesh; only sync RNG state in WORLD when manual_seed has not been called (pytorch#141223)

pobin6 pushed a commit to pobin6/pytorch that referenced this pull request Dec 5, 2024
… avoid overflow (pytorch#141532)
@github-actions github-actions bot deleted the gh/XilunWu/103/head branch December 30, 2024 02:08

Labels

Merged · module: dtensor · oncall: distributed · release notes: distributed (dtensor)
