Update real device in FSDP state_dict_utils by ankurneog · Pull Request #134994 · pytorch/pytorch · GitHub

Conversation

@ankurneog

@ankurneog ankurneog commented Sep 3, 2024

Motivation

The device reported by tensor.device for both sharded and non-sharded tensors is assumed to be cuda by default. As a result, the FSDP unit tests fail with the errors below on builds without CUDA. This change derives the actual device type from the tensor that was created.

[rank3]   File "/root/repos/pytorch-training-tests/tests/pytorch/v2.4.0/distributed_hpu/fsdp/test_fsdp_dtensor_state_dict.py", line 143, in test_dtensor_sharded_tensor_state_dict_identical
[rank3]     sharded_tensor_sd = ref_model.state_dict()
[rank3]   File "/usr/local/lib/python3.10/dist-packages/torch/nn/modules/module.py", line 1944, in state_dict
[rank3]     hook_result = hook(self, destination, prefix, local_metadata)
[rank3]   File "/usr/local/lib/python3.10/dist-packages/torch/utils/_contextlib.py", line 116, in decorate_context
[rank3]     return func(*args, **kwargs)
[rank3]   File "/usr/local/lib/python3.10/dist-packages/torch/distributed/fsdp/_state_dict_utils.py", line 752, in _post_state_dict_hook
[rank3]     tensor.device,
[rank3]   File "/usr/local/lib/python3.10/dist-packages/typing_extensions.py", line 2853, in wrapper
[rank3]     return arg(*args, **kwargs)
[rank3]   File "/usr/local/lib/python3.10/dist-packages/torch/distributed/_shard/sharded_tensor/api.py", line 1152, in __torch_function__
[rank3]     return dispatch(st_instance, func)
[rank3]   File "/usr/local/lib/python3.10/dist-packages/torch/distributed/_shard/sharded_tensor/api.py", line 1134, in dispatch
[rank3]     return _SHARDED_OPS[func](types, args, kwargs, st._process_group)
[rank3]   File "/usr/local/lib/python3.10/dist-packages/torch/distributed/_shard/op_registry_utils.py", line 33, in wrapper
[rank3]     return wrapped_func(types, args, kwargs, process_group)
[rank3]   File "/usr/local/lib/python3.10/dist-packages/torch/distributed/_shard/sharded_tensor/_ops/tensor_ops.py", line 52, in tensor_device
[rank3]     dev = torch.device(torch.cuda.current_device())
[rank3]   File "/usr/local/lib/python3.10/dist-packages/torch/cuda/__init__.py", line 878, in current_device
[rank3]     _lazy_init()
[rank3]   File "/usr/local/lib/python3.10/dist-packages/torch/cuda/__init__.py", line 305, in _lazy_init
[rank3]     raise AssertionError("Torch not compiled with CUDA enabled")
[rank3] AssertionError: Torch not compiled with CUDA enabled 
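
The failing frame is dev = torch.device(torch.cuda.current_device()) inside the sharded tensor_device op, which assumes a CUDA backend even when the shards actually live on CPU or HPU. Below is a minimal, hypothetical sketch of the idea behind the fix (the helper name and the CPU fallback are assumptions for illustration, not the actual PR diff): resolve the device from the tensor that was materialized rather than from the CUDA runtime.

```python
# Hypothetical sketch (not the actual PR diff): derive the real device from the
# tensor that was actually created instead of assuming torch.cuda.current_device().
import torch
from torch.distributed._shard.sharded_tensor import ShardedTensor
from torch.distributed._tensor import DTensor


def _infer_device(tensor: torch.Tensor) -> torch.device:
    """Best-effort device lookup that works on CPU, CUDA, HPU, etc.

    ShardedTensor.device dispatches to a sharded op that assumes CUDA, so we
    inspect a local shard instead; DTensor and plain tensors already report
    the backing device through .to_local() / .device.
    """
    if isinstance(tensor, ShardedTensor):
        local_shards = tensor.local_shards()
        if local_shards:
            return local_shards[0].tensor.device
        return torch.device("cpu")  # no local shard on this rank; fall back to CPU
    if isinstance(tensor, DTensor):
        return tensor.to_local().device
    return tensor.device
```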

cc @XilunWu @H-Huang @awgu @kwen2501 @wanchaol @fegin @fduwjj @wz337 @wconstab @d4l3k @c-p-i-o

@pytorch-bot

pytorch-bot bot commented Sep 3, 2024

🔗 Helpful Links

🧪 See artifacts and rendered test results at hud.pytorch.org/pr/134994

Note: Links to docs will display an error until the docs builds have been completed.

❌ 1 New Failure, 2 Unrelated Failures

As of commit c8fc022 with merge base c977bb7:

NEW FAILURE - The following job has failed:

FLAKY - The following job failed but was likely due to flakiness present on trunk:

UNSTABLE - The following job failed but was likely due to flakiness present on trunk and has been marked as unstable:

This comment was automatically generated by Dr. CI and updates every 15 minutes.

@pytorch-bot pytorch-bot bot added the oncall: distributed and release notes: distributed (fsdp) labels Sep 3, 2024
@fegin
Contributor

fegin commented Sep 3, 2024

@ankurneog Thanks for the PR. Curious about the use case; my understanding is that FSDP cannot be used in a CPU-only environment?

@colesbury colesbury requested a review from wanchaol September 3, 2024 18:58
@colesbury colesbury added the triaged label Sep 3, 2024
@ankurneog
Author

@ankurneog Thanks for the PR. Curious about the use case; my understanding is that FSDP cannot be used in a CPU-only environment?

@fegin: This is for the Intel Gaudi / HPU device.
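
For context, here is a hedged sketch of what FSDP usage on a non-CUDA accelerator such as HPU can look like; the hccl backend, the habana_frameworks imports, and the exact device_id argument are assumptions about a typical Gaudi setup, not something specified in this PR. The point is only that the process group and tensors never touch CUDA, so any device query inside FSDP has to come from the tensors themselves.

```python
# Hypothetical HPU example (assumes the Habana PyTorch bridge is installed).
import torch
import torch.distributed as dist
import habana_frameworks.torch.core  # noqa: F401  registers the "hpu" device (assumption)
import habana_frameworks.torch.distributed.hccl  # noqa: F401  registers the "hccl" backend (assumption)
from torch.distributed.fsdp import FullyShardedDataParallel as FSDP

dist.init_process_group(backend="hccl")  # Habana collective-communication backend

model = torch.nn.Linear(16, 16).to("hpu")
fsdp_model = FSDP(model, device_id=torch.device("hpu"))

# Saving the state dict used to raise "Torch not compiled with CUDA enabled"
# because the device lookup assumed CUDA; with this change it reports hpu.
state_dict = fsdp_model.state_dict()
```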

@ankurneog ankurneog force-pushed the fsdp_state_dict_device branch 2 times, most recently from b246e4e to c05d748 Compare September 4, 2024 16:08
@ankurneog
Author

@hippocookie: Can you please help with the approval? Thanks.

@zeshengzong
Contributor

@hippocookie: Can you please help with the approval? Thanks.

Sorry, I don't have permission to do that; this needs help from @fegin :D

@ankurneog ankurneog force-pushed the fsdp_state_dict_device branch from c05d748 to c824c7d Compare September 6, 2024 03:51
@ankurneog
Author

@fegin: Could you please help with the approval? Thanks.

@fegin fegin added the ciflow/trunk and ciflow/periodic labels Sep 9, 2024
@fegin
Contributor

fegin commented Sep 9, 2024

Let me initiate the tests and see if this PR breaks any existing tests before stamping it.

@ankurneog
Author

@fegin: Gentle reminder, could you please help with the approval? Thanks.

@fegin
Contributor

fegin commented Sep 13, 2024

Can you rebase and resubmit? There is too much noise in the CI. Thanks!

@wz337
Contributor

wz337 commented Sep 14, 2024

@pytorchmergebot rebase

@pytorchmergebot
Collaborator

@pytorchbot started a rebase job onto refs/remotes/origin/viable/strict. Check the current status here

@pytorchmergebot
Collaborator

Successfully rebased fsdp_state_dict_device onto refs/remotes/origin/viable/strict, please pull locally before adding more changes (for example, via git checkout fsdp_state_dict_device && git pull --rebase)

@fegin
Contributor

fegin commented Sep 16, 2024

@pytorchbot merge

@pytorchmergebot
Collaborator

Merge started

Your change will be merged once all checks pass (ETA 0-4 Hours).

Learn more about merging in the wiki.

Questions? Feedback? Please reach out to the PyTorch DevX Team

Advanced Debugging
Check the merge workflow status here.

@pytorchmergebot
Collaborator

Merge failed

Reason: 1 jobs have failed, first few of them are: periodic / ios-build-test / build (default, 1, 1, macos-14-xlarge, SIMULATOR, arm64, 1, 0, 1)

Details for Dev Infra team (raised by workflow job)

@ankurneog
Author

@pytorchbot rebase

@pytorchmergebot
Collaborator

@pytorchbot started a rebase job onto refs/remotes/origin/viable/strict. Check the current status here

@pytorchmergebot
Collaborator

Successfully rebased fsdp_state_dict_device onto refs/remotes/origin/viable/strict, please pull locally before adding more changes (for example, via git checkout fsdp_state_dict_device && git pull --rebase)

@ankurneog
Author

@pytorchbot merge

@pytorchmergebot
Collaborator

Merge started

Your change will be merged once all checks pass (ETA 0-4 Hours).

Learn more about merging in the wiki.

Questions? Feedback? Please reach out to the PyTorch DevX Team

Advanced Debugging
Check the merge workflow status here.

@pytorchmergebot
Collaborator

Merge failed

Reason: 1 jobs have failed, first few of them are: periodic / ios-build-test / build (default, 1, 1, macos-14-xlarge, SIMULATOR, arm64, 1, 0, 1)

Details for Dev Infra team (raised by workflow job)

@ankurneog
Author

@fegin: I believe the failures are not related to my change. Could you please help with the merge?

@fegin
Contributor

fegin commented Sep 17, 2024

@pytorchbot merge -f "The failing test is not related."

@pytorchmergebot
Collaborator

Merge started

Your change will be merged immediately since you used the force (-f) flag, bypassing any CI checks (ETA: 1-5 minutes). Please use -f as last resort and instead consider -i/--ignore-current to continue the merge ignoring current failures. This will allow currently pending tests to finish and report signal before the merge.

Learn more about merging in the wiki.

Questions? Feedback? Please reach out to the PyTorch DevX Team

Advanced Debugging
Check the merge workflow status here.

Chao1Han pushed a commit to Chao1Han/pytorch that referenced this pull request Sep 20, 2024

Pull Request resolved: pytorch#134994
Approved by: https://github.com/fegin
aostrowski-hbn pushed a commit to HabanaAI/pytorch-fork that referenced this pull request Jan 7, 2025
Access the tensor.device variable in the right format from ShardedTensor, DTensor, and Tensor.

PR : pytorch#134994

Change-Id: Id1ae919b8cd902899386ee756af680b872fee8c9

Labels

ciflow/periodic - Trigger jobs ran periodically on master (periodic.yml) on the PR
ciflow/trunk - Trigger trunk jobs on your pull request
Merged
oncall: distributed - Add this issue/PR to distributed oncall triage queue
open source
release notes: distributed (fsdp) - release notes category
triaged - This issue has been looked at by a team member, and triaged and prioritized into an appropriate module
