[FSDP2] Add set_unshard_in_backward(bool)
#137922
Conversation
🔗 Helpful links: 🧪 See artifacts and rendered test results at hud.pytorch.org/pr/137922
✅ No failures as of commit 6796d05 with merge base 0e4d426.
For some expert use cases, the user knows that some parameters are not required for backward, so we can skip the unshard in backward. One example is the embedding weight.

cc @XilunWu @H-Huang @kwen2501 @wanchaol @fegin @fduwjj @wz337 @wconstab @d4l3k @c-p-i-o
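For illustration, here is a minimal usage sketch (not part of this PR's diff). It assumes torch.distributed is already initialized (e.g. via torchrun) and uses the FSDP2 `fully_shard` entry point from `torch.distributed._composable.fsdp` (its location at the time of this PR); the `ToyLM` module and its sizes are made up for the example.

```
import torch
import torch.nn as nn
from torch.distributed._composable.fsdp import fully_shard


class ToyLM(nn.Module):
    def __init__(self, vocab_size: int = 1024, dim: int = 256) -> None:
        super().__init__()
        self.tok_embeddings = nn.Embedding(vocab_size, dim)
        self.output = nn.Linear(dim, vocab_size)

    def forward(self, tokens: torch.Tensor) -> torch.Tensor:
        return self.output(self.tok_embeddings(tokens))


model = ToyLM()
# Shard the embedding separately so it gets its own FSDP parameter group.
fully_shard(model.tok_embeddings)
fully_shard(model)

# Embedding backward scatter-adds grad_output into grad_weight using the input
# indices and never reads the weight values, so the unsharded weight is not
# needed in backward and its all-gather there can be skipped.
model.tok_embeddings.set_unshard_in_backward(False)
```

Skipping the backward unshard is only safe when, as for the embedding weight here, the backward computation does not read the parameter's values.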
cc: @weifengpy let me know how you feel about this 😃
make sense to me
```
@property
def unsharded_param(self) -> nn.Parameter:  # ND
    self._assert_in_states(ShardedState.UNSHARDED)
```
The main change here is that we used to define `unsharded_param` as a property so that we could assert that it is only accessed in the unsharded state. However, with this new API, it is possible to get grads on the unsharded parameter without being in the unsharded state (since the unsharded parameter data is not allocated).
In the future, we may refactor to get rid of `unsharded_param` as a property.
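As a rough sketch of why the assertion has to be relaxed (illustrative only, not the actual FSDPParam code; the `_unsharded_param` backing attribute is an assumed name):

```
@property
def unsharded_param(self) -> nn.Parameter:  # ND
    # Previously: self._assert_in_states(ShardedState.UNSHARDED)
    # With set_unshard_in_backward(False), autograd can produce .grad on this
    # parameter while the group is still sharded (its data unallocated), so
    # asserting the unsharded state here would spuriously fail.
    return self._unsharded_param
```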
@pytorchbot merge
Merge started: your change will be merged once all checks pass (ETA 0-4 hours).
…on 8 GPUs faster"

Requires pytorch/pytorch#137922

```
TORCH_NCCL_AVOID_RECORD_STREAMS=1 PYTORCH_CUDA_ALLOC_CONF="expandable_segments:True" CONFIG_FILE=train_configs/llama3_8b.toml ./run_llama_train.sh
```

```
[rank0]:2024-10-14 11:58:53,071 - root - INFO - step: 1 loss: 12.2208 memory: 66.44GiB(69.93%) wps: 882 mfu: 5.17%
[rank0]:2024-10-14 11:58:53,071 - root - INFO - Synchronizing and adjusting timeout for all ProcessGroups to 0:01:40
[rank0]:2024-10-14 11:58:54,196 - root - INFO - step: 2 loss: 12.0630 memory: 73.96GiB(77.85%) wps: 7,282 mfu: 42.64%
[rank0]:2024-10-14 11:58:55,322 - root - INFO - step: 3 loss: 11.7272 memory: 73.96GiB(77.85%) wps: 7,276 mfu: 42.60%
[rank0]:2024-10-14 11:58:56,448 - root - INFO - step: 4 loss: 11.2526 memory: 73.96GiB(77.85%) wps: 7,280 mfu: 42.63%
[rank0]:2024-10-14 11:58:57,575 - root - INFO - step: 5 loss: 10.7972 memory: 73.96GiB(77.85%) wps: 7,268 mfu: 42.56%
[rank0]:2024-10-14 11:58:58,699 - root - INFO - step: 6 loss: 10.5048 memory: 73.96GiB(77.85%) wps: 7,293 mfu: 42.70%
[rank0]:2024-10-14 11:58:59,824 - root - INFO - step: 7 loss: 10.3384 memory: 73.96GiB(77.85%) wps: 7,285 mfu: 42.66%
[rank0]:2024-10-14 11:59:00,952 - root - INFO - step: 8 loss: 10.3164 memory: 73.96GiB(77.85%) wps: 7,266 mfu: 42.55%
[rank0]:2024-10-14 11:59:02,083 - root - INFO - step: 9 loss: 10.0995 memory: 73.96GiB(77.85%) wps: 7,247 mfu: 42.44%
[rank0]:2024-10-14 11:59:03,211 - root - INFO - step: 10 loss: 9.9308 memory: 73.96GiB(77.85%) wps: 7,264 mfu: 42.54%
[rank0]:2024-10-14 11:59:04,337 - root - INFO - step: 11 loss: 9.5785 memory: 73.96GiB(77.85%) wps: 7,275 mfu: 42.60%
[rank0]:2024-10-14 11:59:05,465 - root - INFO - step: 12 loss: 9.5265 memory: 73.96GiB(77.85%) wps: 7,267 mfu: 42.56%
[rank0]:2024-10-14 11:59:06,595 - root - INFO - step: 13 loss: 9.3497 memory: 73.96GiB(77.85%) wps: 7,252 mfu: 42.47%
[rank0]:2024-10-14 11:59:06,601 - root - WARNING - Dataset c4_test is being re-looped
```
…on 8 GPUs faster"

Requires pytorch/pytorch#137922

```
TORCH_NCCL_AVOID_RECORD_STREAMS=1 PYTORCH_CUDA_ALLOC_CONF="expandable_segments:True" CONFIG_FILE=train_configs/llama3_8b.toml ./run_llama_train.sh
```

```
[rank0]:2024-10-21 21:23:32,899 - root - INFO - Training starts at step 1, with local batch size 1, global batch size 8, sequence length 8192, total steps 1000 (warmup 200)
[rank0]:/data/users/andgu/pytorch/torch/_inductor/lowering.py:1759: UserWarning: Torchinductor does not support code generation for complex operators. Performance may be worse than eager.
[rank0]:  warnings.warn(
[rank0]:/data/users/andgu/pytorch/torch/autograd/graph.py:825: UserWarning: cuDNN SDPA backward got grad_output.strides() != output.strides(), attempting to materialize a grad_output with matching strides... (Triggered internally at /data/users/andgu/pytorch/aten/src/ATen/native/cudnn/MHA.cpp:674.)
[rank0]:  return Variable._execution_engine.run_backward(  # Calls into the C++ engine to run the backward pass
[rank0]:2024-10-21 21:23:42,336 - root - INFO - step: 1 loss: 12.2799 memory: 63.45GiB(66.79%) wps: 868 mfu: 5.08%
[rank0]:2024-10-21 21:23:42,336 - root - INFO - Synchronizing and adjusting timeout for all ProcessGroups to 0:01:40
[rank0]:2024-10-21 21:23:43,539 - root - INFO - step: 2 loss: 12.1023 memory: 70.96GiB(74.69%) wps: 6,813 mfu: 39.90%
[rank0]:2024-10-21 21:23:44,667 - root - INFO - step: 3 loss: 11.7899 memory: 70.96GiB(74.69%) wps: 7,263 mfu: 42.53%
[rank0]:2024-10-21 21:23:45,795 - root - INFO - step: 4 loss: 11.3163 memory: 70.96GiB(74.69%) wps: 7,264 mfu: 42.54%
[rank0]:2024-10-21 21:23:46,923 - root - INFO - step: 5 loss: 10.8908 memory: 70.96GiB(74.69%) wps: 7,262 mfu: 42.52%
[rank0]:2024-10-21 21:23:48,050 - root - INFO - step: 6 loss: 10.4146 memory: 70.96GiB(74.69%) wps: 7,275 mfu: 42.60%
[rank0]:2024-10-21 21:23:49,174 - root - INFO - step: 7 loss: 10.1523 memory: 70.96GiB(74.69%) wps: 7,288 mfu: 42.68%
[rank0]:2024-10-21 21:23:50,306 - root - INFO - step: 8 loss: 10.2847 memory: 70.96GiB(74.69%) wps: 7,240 mfu: 42.40%
[rank0]:2024-10-21 21:23:51,434 - root - INFO - step: 9 loss: 10.0047 memory: 70.96GiB(74.69%) wps: 7,263 mfu: 42.53%
[rank0]:2024-10-21 21:23:52,560 - root - INFO - step: 10 loss: 9.9882 memory: 70.96GiB(74.69%) wps: 7,279 mfu: 42.63%
[rank0]:2024-10-21 21:23:53,685 - root - INFO - step: 11 loss: 9.6261 memory: 70.96GiB(74.69%) wps: 7,285 mfu: 42.66%
[rank0]:2024-10-21 21:23:54,813 - root - INFO - step: 12 loss: 9.5229 memory: 70.96GiB(74.69%) wps: 7,265 mfu: 42.54%
[rank0]:2024-10-21 21:23:55,944 - root - INFO - step: 13 loss: 9.4371 memory: 70.96GiB(74.69%) wps: 7,244 mfu: 42.42%
[rank0]:2024-10-21 21:23:55,950 - root - WARNING - Dataset c4_test is being re-looped
```
Stack from ghstack (oldest at bottom):
#137922 [FSDP2] Add set_unshard_in_backward(bool)

For some expert use cases, the user knows some parameters are not required for backward, so we can skip the unshard in backward. One example is the embedding weight.
cc @XilunWu @H-Huang @kwen2501 @wanchaol @fegin @fduwjj @wz337 @wconstab @d4l3k @c-p-i-o