[FSDP2] Fixed incorrect tensor meta after .to(dtype)
#137593
Dr. CI: See artifacts and rendered test results at hud.pytorch.org/pr/137593.
✅ No failures as of commit 28daaa3 with merge base d1b87e2. (This comment was automatically generated by Dr. CI and updates every 15 minutes.)
    if updated_local_tensor:
        # Only change the local tensor object if needed
        self.sharded_param._local_tensor = local_tensor[: self.sharded_size[0]]
    self._sharding_spec = self.sharded_param._spec
Question: instead of caching `self._sharding_spec`, would it make sense to have it be a property that always just queries `self.sharded_param._spec`?
That sounds good!
took a quick look -- this will require some refactoring, so I will defer this to a separate PR
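The property-based alternative suggested in this thread could look roughly like the sketch below. `FakeSpec`, `FakeShardedParam`, and `ParamHolder` are hypothetical stand-ins, not the real `DTensorSpec`, `DTensor`, or FSDP classes; only the `_sharding_spec` and `sharded_param._spec` names come from the diff.

```python
from dataclasses import dataclass


@dataclass
class FakeSpec:
    """Hypothetical stand-in for a DTensorSpec; only tracks a dtype name."""
    dtype: str


class FakeShardedParam:
    """Hypothetical stand-in for the sharded DTensor parameter."""

    def __init__(self, dtype: str) -> None:
        self._spec = FakeSpec(dtype)


class ParamHolder:
    """Hypothetical stand-in for the class that owns `sharded_param`."""

    def __init__(self, sharded_param: FakeShardedParam) -> None:
        self.sharded_param = sharded_param

    @property
    def _sharding_spec(self) -> FakeSpec:
        # Always read through to the parameter, so there is no cached copy
        # that can go stale when the parameter's spec is replaced.
        return self.sharded_param._spec


holder = ParamHolder(FakeShardedParam("float32"))
# Simulate a dtype conversion that installs a new spec on the parameter.
holder.sharded_param._spec = FakeSpec("float64")
print(holder._sharding_spec.dtype)
```

With a read-through property there is nothing to refresh after a conversion, but any code that currently assigns `_sharding_spec` would have to change, which is the kind of refactoring the thread defers.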
@pytorchbot merge
Merge started: Your change will be merged once all checks pass (ETA 0-4 Hours).
## Overview
This PR adds a `shard_placement_fn: Optional[Callable[[nn.Parameter], Optional[Shard]]]` arg to `fully_shard` that allows users to specify FSDP sharding on a nonzero tensor dim. When sharding on a nonzero dim, that dim's size must be divisible by the FSDP shard world size.
```
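# Example: shard each parameter on its largest dim.
# Import paths below assume a recent PyTorch release.
from typing import Optional
import torch.nn as nn
from torch.distributed.tensor import Shard

def shard_placement_fn(param: nn.Parameter) -> Optional[Shard]:
    largest_dim = largest_dim_size = -1
    for dim, dim_size in enumerate(param.shape):
        if dim_size > largest_dim_size:
            largest_dim = dim
            largest_dim_size = dim_size
    return Shard(largest_dim)

# `module` is an existing nn.Module
fully_shard(module, shard_placement_fn=shard_placement_fn)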
```
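Because a nonzero shard dim must be divisible by the FSDP shard world size, a placement function may want to check divisibility before picking a dim. Below is a minimal sketch of that selection logic in plain Python. `pick_shard_dim` and its `world_size` argument are hypothetical helpers, and returning `None` when no dim qualifies is an assumption of this sketch, not documented behavior of `fully_shard`.

```python
from typing import Optional, Sequence


def pick_shard_dim(shape: Sequence[int], world_size: int) -> Optional[int]:
    """Return the largest dim evenly divisible by world_size, else None.

    Hypothetical helper: None is meant to signal "use the default placement".
    """
    best_dim: Optional[int] = None
    best_size = -1
    for dim, dim_size in enumerate(shape):
        if dim_size % world_size == 0 and dim_size > best_size:
            best_dim, best_size = dim, dim_size
    return best_dim


print(pick_shard_dim((3, 8), world_size=4))  # dim 1 qualifies: 8 % 4 == 0
print(pick_shard_dim((3, 5), world_size=4))  # no dim is divisible by 4
```

Inside a `shard_placement_fn`, the result could be wrapped as `Shard(dim)` when it is not `None`, and `None` returned otherwise, which the `Optional[Shard]` return type permits.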
## Follow-Ups
- **Copy kernels:** For all-gather copy-out, we currently copy out to temporaries and then chunk-dim-0 -> cat-shard-dim, incurring an extra copy for parameters sharded on a nonzero tensor dim. Similarly, for reduce-scatter copy-in, we currently chunk-shard-dim -> cat-dim-0, incurring an extra copy for gradients sharded on a nonzero tensor dim. @yifuwang has ideas for adding additional split size args to the copy ops that would allow fusing these extra copies into the existing all-gather copy-out and reduce-scatter copy-in.
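The two reshuffles named above can be illustrated on a tiny 2-D example. This is a pure-Python sketch on nested lists, not the actual copy kernels; the helper names are hypothetical and only mirror what `torch.chunk` and `torch.cat` do along a given dim.

```python
from typing import List

Matrix = List[List[int]]


def chunk_dim0(x: Matrix, n: int) -> List[Matrix]:
    """Split rows into n equal chunks (like torch.chunk(x, n, dim=0))."""
    rows = len(x) // n
    return [x[i * rows:(i + 1) * rows] for i in range(n)]


def chunk_dim1(x: Matrix, n: int) -> List[Matrix]:
    """Split columns into n equal chunks (like torch.chunk(x, n, dim=1))."""
    cols = len(x[0]) // n
    return [[row[i * cols:(i + 1) * cols] for row in x] for i in range(n)]


def cat_dim0(chunks: List[Matrix]) -> Matrix:
    """Stack chunks row-wise (like torch.cat(chunks, dim=0))."""
    return [row for chunk in chunks for row in chunk]


def cat_dim1(chunks: List[Matrix]) -> Matrix:
    """Join chunks column-wise (like torch.cat(chunks, dim=1))."""
    return [[v for chunk in chunks for v in chunk[r]]
            for r in range(len(chunks[0]))]


full = [[0, 1, 2, 3], [4, 5, 6, 7]]  # 2x4 tensor, sharded on dim 1
world_size = 2

# Reduce-scatter copy-in direction: chunk-shard-dim -> cat-dim-0, so each
# rank's shard occupies a contiguous block of rows.
stacked = cat_dim0(chunk_dim1(full, world_size))
# All-gather copy-out direction: chunk-dim-0 -> cat-shard-dim.
restored = cat_dim1(chunk_dim0(stacked, world_size))
print(stacked)
print(restored == full)
```

Each direction materializes an intermediate list of chunks before concatenating, which corresponds to the extra copy this follow-up item aims to fuse away.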
Pull Request resolved: #137496
Approved by: https://github.com/weifengpy
ghstack dependencies: #137593
Stack from ghstack (oldest at bottom):
- `shard_placement_fn` arg #137496
- `.to(dtype)` #137593

This fixes #137522. After a method that changes the module parameters (like `.to(torch.float64)`), we need to update the `DTensorSpec`, whose `TensorMeta`'s dtype may have changed.

cc @XilunWu @H-Huang @kwen2501 @wanchaol @fegin @fduwjj @wz337 @wconstab @d4l3k @c-p-i-o
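To illustrate the failure mode described in this PR, here is a toy sketch of how a cached spec can go stale after a dtype conversion and how re-reading it fixes the mismatch. `Meta`, `Spec`, `Param`, and `Holder` are hypothetical stand-ins for `TensorMeta`, `DTensorSpec`, the sharded parameter, and its owner. It is an assumption of this sketch that the conversion installs a new spec object rather than mutating the old one.

```python
from dataclasses import dataclass, replace
from typing import Tuple


@dataclass(frozen=True)
class Meta:
    """Hypothetical stand-in for TensorMeta (shape and dtype only)."""
    shape: Tuple[int, ...]
    dtype: str


@dataclass(frozen=True)
class Spec:
    """Hypothetical stand-in for DTensorSpec."""
    tensor_meta: Meta


class Param:
    """Hypothetical stand-in for the sharded parameter."""

    def __init__(self, spec: Spec) -> None:
        self.spec = spec

    def to(self, dtype: str) -> None:
        # Assumed behavior: conversion replaces the spec with a new object.
        self.spec = Spec(replace(self.spec.tensor_meta, dtype=dtype))


class Holder:
    """Hypothetical owner that keeps a cached reference to the spec."""

    def __init__(self, param: Param) -> None:
        self.param = param
        self.cached_spec = param.spec

    def reset(self) -> None:
        # Re-read the spec so the cached metadata matches the parameter.
        self.cached_spec = self.param.spec


param = Param(Spec(Meta((4, 4), "float32")))
holder = Holder(param)
param.to("float64")
stale_dtype = holder.cached_spec.tensor_meta.dtype  # still the old dtype
holder.reset()
fresh_dtype = holder.cached_spec.tensor_meta.dtype  # now matches the param
print(stale_dtype, fresh_dtype)
```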