[FSDP2] Add `set_unshard_in_backward(bool)` by awgu · Pull Request #137922 · pytorch/pytorch · GitHub

Conversation


@awgu awgu commented Oct 14, 2024

Stack from ghstack (oldest at bottom):

For some expert use cases, the user knows that certain parameters are not required for backward computation, so we can skip unsharding them in backward. One example is the embedding weight.
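A minimal usage sketch follows, assuming the FSDP2 `fully_shard` API, an already-initialized process group, and illustrative module sizes; the import path may differ by release (for example, `torch.distributed._composable.fsdp` on builds from this era versus `torch.distributed.fsdp` later).

```python
# Hedged sketch: skip the backward all-gather for an embedding's weight.
# Assumes torch.distributed is initialized; module sizes are illustrative.
import torch.nn as nn
from torch.distributed._composable.fsdp import fully_shard  # path may differ by release

model = nn.Sequential(
    nn.Embedding(32000, 4096),
    nn.Linear(4096, 32000),
)
# Shard the embedding as its own FSDP unit so it can be configured separately.
fully_shard(model[0])
fully_shard(model)

# The embedding weight is not needed to compute its gradient in backward,
# so skip unsharding (all-gathering) it during backward.
model[0].set_unshard_in_backward(False)
```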

cc @XilunWu @H-Huang @kwen2501 @wanchaol @fegin @fduwjj @wz337 @wconstab @d4l3k @c-p-i-o


pytorch-bot bot commented Oct 14, 2024

🔗 Helpful Links

🧪 See artifacts and rendered test results at hud.pytorch.org/pr/137922

Note: Links to docs will display an error until the docs builds have been completed.

✅ No Failures

As of commit 6796d05 with merge base 0e4d426:
💚 Looks good so far! There are no failures yet. 💚

This comment was automatically generated by Dr. CI and updates every 15 minutes.

@pytorch-bot pytorch-bot bot added oncall: distributed Add this issue/PR to distributed oncall triage queue release notes: distributed (fsdp) release notes category labels Oct 14, 2024
@awgu awgu added release notes: distributed (fsdp2) release notes category and removed release notes: distributed (fsdp) release notes category labels Oct 14, 2024
awgu pushed a commit that referenced this pull request Oct 14, 2024
ghstack-source-id: 00710b1
Pull Request resolved: #137922

awgu commented Oct 14, 2024

cc: @weifengpy let me know how you feel about this 😃


@weifengpy weifengpy left a comment


Makes sense to me.


@property
def unsharded_param(self) -> nn.Parameter:  # ND
    self._assert_in_states(ShardedState.UNSHARDED)
Collaborator Author


The main change here is that unsharded_param used to be defined as a property so that we could assert it is only accessed while in the unsharded state. With this new API, however, it is possible to get grads on the unsharded parameter without being in the unsharded state (since the unsharded parameter data is not allocated).

We may refactor in the future to get rid of unsharded_param as a property.
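As a minimal illustration (plain autograd, not FSDP internals) of why a gradient can exist for a weight whose unsharded data was never materialized in backward: the embedding backward computes the weight gradient from the indices and the output gradient alone.

```python
import torch
import torch.nn.functional as F

# The weight values are not read during the embedding backward; grad_weight is
# a scatter-add of grad_output at the looked-up indices.
weight = torch.randn(10, 4, requires_grad=True)
indices = torch.tensor([1, 3, 3])
out = F.embedding(indices, weight)
out.sum().backward()
print(weight.grad)  # only rows 1 and 3 are nonzero
```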

@awgu awgu added the ciflow/trunk Trigger trunk jobs on your pull request label Oct 15, 2024
@awgu awgu marked this pull request as ready for review October 15, 2024 15:58

awgu commented Oct 15, 2024

@pytorchbot merge

@pytorchmergebot

Merge started

Your change will be merged once all checks pass (ETA 0-4 Hours).

Learn more about merging in the wiki.

Questions? Feedback? Please reach out to the PyTorch DevX Team

Advanced Debugging: check the merge workflow status here.

awgu pushed a commit to pytorch/torchtitan that referenced this pull request Oct 22, 2024
…on 8 GPUs faster"


Requires pytorch/pytorch#137922

```
TORCH_NCCL_AVOID_RECORD_STREAMS=1 PYTORCH_CUDA_ALLOC_CONF="expandable_segments:True" CONFIG_FILE=train_configs/llama3_8b.toml ./run_llama_train.sh 
```

```
[rank0]:2024-10-14 11:58:53,071 - root - INFO - step:  1  loss: 12.2208  memory: 66.44GiB(69.93%)  wps: 882  mfu: 5.17%
[rank0]:2024-10-14 11:58:53,071 - root - INFO - Synchronizing and adjusting timeout for all ProcessGroups to 0:01:40
[rank0]:2024-10-14 11:58:54,196 - root - INFO - step:  2  loss: 12.0630  memory: 73.96GiB(77.85%)  wps: 7,282  mfu: 42.64%
[rank0]:2024-10-14 11:58:55,322 - root - INFO - step:  3  loss: 11.7272  memory: 73.96GiB(77.85%)  wps: 7,276  mfu: 42.60%
[rank0]:2024-10-14 11:58:56,448 - root - INFO - step:  4  loss: 11.2526  memory: 73.96GiB(77.85%)  wps: 7,280  mfu: 42.63%
[rank0]:2024-10-14 11:58:57,575 - root - INFO - step:  5  loss: 10.7972  memory: 73.96GiB(77.85%)  wps: 7,268  mfu: 42.56%
[rank0]:2024-10-14 11:58:58,699 - root - INFO - step:  6  loss: 10.5048  memory: 73.96GiB(77.85%)  wps: 7,293  mfu: 42.70%
[rank0]:2024-10-14 11:58:59,824 - root - INFO - step:  7  loss: 10.3384  memory: 73.96GiB(77.85%)  wps: 7,285  mfu: 42.66%
[rank0]:2024-10-14 11:59:00,952 - root - INFO - step:  8  loss: 10.3164  memory: 73.96GiB(77.85%)  wps: 7,266  mfu: 42.55%
[rank0]:2024-10-14 11:59:02,083 - root - INFO - step:  9  loss: 10.0995  memory: 73.96GiB(77.85%)  wps: 7,247  mfu: 42.44%
[rank0]:2024-10-14 11:59:03,211 - root - INFO - step: 10  loss:  9.9308  memory: 73.96GiB(77.85%)  wps: 7,264  mfu: 42.54%
[rank0]:2024-10-14 11:59:04,337 - root - INFO - step: 11  loss:  9.5785  memory: 73.96GiB(77.85%)  wps: 7,275  mfu: 42.60%
[rank0]:2024-10-14 11:59:05,465 - root - INFO - step: 12  loss:  9.5265  memory: 73.96GiB(77.85%)  wps: 7,267  mfu: 42.56%
[rank0]:2024-10-14 11:59:06,595 - root - INFO - step: 13  loss:  9.3497  memory: 73.96GiB(77.85%)  wps: 7,252  mfu: 42.47%
[rank0]:2024-10-14 11:59:06,601 - root - WARNING - Dataset c4_test is being re-looped
```

[ghstack-poisoned]
awgu pushed a commit to pytorch/torchtitan that referenced this pull request Oct 22, 2024
…on 8 GPUs faster"


Requires pytorch/pytorch#137922

```
TORCH_NCCL_AVOID_RECORD_STREAMS=1 PYTORCH_CUDA_ALLOC_CONF="expandable_segments:True" CONFIG_FILE=train_configs/llama3_8b.toml ./run_llama_train.sh 
```

```
[rank0]:2024-10-21 21:23:32,899 - root - INFO - Training starts at step 1, with local batch size 1, global batch size 8, sequence length 8192, total steps 1000 (warmup 200)
[rank0]:/data/users/andgu/pytorch/torch/_inductor/lowering.py:1759: UserWarning: Torchinductor does not support code generation for complex operators. Performance may be worse than eager.
[rank0]:  warnings.warn(
[rank0]:/data/users/andgu/pytorch/torch/autograd/graph.py:825: UserWarning: cuDNN SDPA backward got grad_output.strides() != output.strides(), attempting to materialize a grad_output with matching strides... (Triggered internally at /data/users/andgu/pytorch/aten/src/ATen/native/cudnn/MHA.cpp:674.)
[rank0]:  return Variable._execution_engine.run_backward(  # Calls into the C++ engine to run the backward pass
[rank0]:2024-10-21 21:23:42,336 - root - INFO - step:  1  loss: 12.2799  memory: 63.45GiB(66.79%)  wps: 868  mfu: 5.08%
[rank0]:2024-10-21 21:23:42,336 - root - INFO - Synchronizing and adjusting timeout for all ProcessGroups to 0:01:40
[rank0]:2024-10-21 21:23:43,539 - root - INFO - step:  2  loss: 12.1023  memory: 70.96GiB(74.69%)  wps: 6,813  mfu: 39.90%
[rank0]:2024-10-21 21:23:44,667 - root - INFO - step:  3  loss: 11.7899  memory: 70.96GiB(74.69%)  wps: 7,263  mfu: 42.53%
[rank0]:2024-10-21 21:23:45,795 - root - INFO - step:  4  loss: 11.3163  memory: 70.96GiB(74.69%)  wps: 7,264  mfu: 42.54%
[rank0]:2024-10-21 21:23:46,923 - root - INFO - step:  5  loss: 10.8908  memory: 70.96GiB(74.69%)  wps: 7,262  mfu: 42.52%
[rank0]:2024-10-21 21:23:48,050 - root - INFO - step:  6  loss: 10.4146  memory: 70.96GiB(74.69%)  wps: 7,275  mfu: 42.60%
[rank0]:2024-10-21 21:23:49,174 - root - INFO - step:  7  loss: 10.1523  memory: 70.96GiB(74.69%)  wps: 7,288  mfu: 42.68%
[rank0]:2024-10-21 21:23:50,306 - root - INFO - step:  8  loss: 10.2847  memory: 70.96GiB(74.69%)  wps: 7,240  mfu: 42.40%
[rank0]:2024-10-21 21:23:51,434 - root - INFO - step:  9  loss: 10.0047  memory: 70.96GiB(74.69%)  wps: 7,263  mfu: 42.53%
[rank0]:2024-10-21 21:23:52,560 - root - INFO - step: 10  loss:  9.9882  memory: 70.96GiB(74.69%)  wps: 7,279  mfu: 42.63%
[rank0]:2024-10-21 21:23:53,685 - root - INFO - step: 11  loss:  9.6261  memory: 70.96GiB(74.69%)  wps: 7,285  mfu: 42.66%
[rank0]:2024-10-21 21:23:54,813 - root - INFO - step: 12  loss:  9.5229  memory: 70.96GiB(74.69%)  wps: 7,265  mfu: 42.54%
[rank0]:2024-10-21 21:23:55,944 - root - INFO - step: 13  loss:  9.4371  memory: 70.96GiB(74.69%)  wps: 7,244  mfu: 42.42%
[rank0]:2024-10-21 21:23:55,950 - root - WARNING - Dataset c4_test is being re-looped
```

[ghstack-poisoned]
@github-actions github-actions bot deleted the gh/awgu/654/head branch November 15, 2024 02:09

Labels

ciflow/trunk, Merged, oncall: distributed, release notes: distributed (fsdp2)
