[FSDP2] Fixed incorrect tensor meta after `.to(dtype)` by awgu · Pull Request #137593 · pytorch/pytorch · GitHub

Conversation

@awgu
Collaborator

@awgu awgu commented Oct 9, 2024

Stack from ghstack (oldest at bottom):

This fixes #137522. After a method that changes the module's parameters (like `.to(torch.float64)`), we need to update the `DTensorSpec`, whose `TensorMeta`'s dtype may have changed.
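For context, here is a minimal sketch of the idea behind the fix (not the exact PR code; the import paths follow the PyTorch 2.5 layout and may differ in other versions): after the cast, rebuild the spec's `TensorMeta` from the post-cast local tensor so the cached dtype matches.

```
# Hedged sketch of the fix's idea, not the exact PR code. Import paths follow
# the PyTorch 2.5 layout and may differ across versions.
from torch.distributed.tensor import DTensor
from torch.distributed.tensor._dtensor_spec import DTensorSpec, TensorMeta


def refresh_tensor_meta(sharded_param: DTensor) -> None:
    """Rebuild the DTensor spec's TensorMeta after a dtype-changing cast."""
    spec = sharded_param._spec
    local_tensor = sharded_param._local_tensor
    new_meta = TensorMeta(
        shape=spec.tensor_meta.shape,
        stride=spec.tensor_meta.stride,
        dtype=local_tensor.dtype,  # pick up the post-cast dtype
    )
    if spec.tensor_meta != new_meta:
        # Replace the spec so downstream consumers see the updated metadata
        sharded_param._spec = DTensorSpec(
            mesh=spec.mesh, placements=spec.placements, tensor_meta=new_meta
        )
```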

cc @XilunWu @H-Huang @kwen2501 @wanchaol @fegin @fduwjj @wz337 @wconstab @d4l3k @c-p-i-o

@pytorch-bot

pytorch-bot bot commented Oct 9, 2024

🔗 Helpful Links

🧪 See artifacts and rendered test results at hud.pytorch.org/pr/137593

Note: Links to docs will display an error until the docs builds have been completed.

✅ No Failures

As of commit 28daaa3 with merge base d1b87e2:
💚 Looks good so far! There are no failures yet. 💚

This comment was automatically generated by Dr. CI and updates every 15 minutes.

awgu pushed a commit that referenced this pull request Oct 9, 2024
ghstack-source-id: d922192
Pull Request resolved: #137593
@pytorch-bot pytorch-bot bot added oncall: distributed Add this issue/PR to distributed oncall triage queue release notes: distributed (fsdp) release notes category labels Oct 9, 2024
@awgu awgu added release notes: distributed (fsdp2) release notes category and removed release notes: distributed (fsdp) release notes category labels Oct 9, 2024
@awgu awgu marked this pull request as ready for review October 9, 2024 15:16
@awgu awgu requested a review from weifengpy October 9, 2024 15:16
@awgu awgu added the ciflow/trunk Trigger trunk jobs on your pull request label Oct 9, 2024
```
if updated_local_tensor:
    # Only change the local tensor object if needed
    self.sharded_param._local_tensor = local_tensor[: self.sharded_size[0]]
self._sharding_spec = self.sharded_param._spec
```
Collaborator

Question: instead of caching `self._sharding_spec`, would it make sense to have it be a property that always just queries `self.sharded_param._spec`?

Collaborator Author

That sounds good!

Collaborator Author

took a quick look -- this will require some refactoring, so I will defer this to a separate PR
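For illustration, a minimal sketch of what the suggested property could look like (the class and attribute names are illustrative, not the actual `FSDPParam` code):

```
# Hedged sketch of the reviewer's suggestion above, deferred to a follow-up PR.
# The class and attribute names are illustrative, not the actual FSDPParam code.
class _FSDPParamSketch:
    def __init__(self, sharded_param):
        self.sharded_param = sharded_param  # a DTensor

    @property
    def _sharding_spec(self):
        # Derive the spec on every access instead of caching it, so a later
        # dtype-changing cast cannot leave a stale cached copy behind.
        return self.sharded_param._spec
```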

@awgu
Collaborator Author

awgu commented Oct 9, 2024

@pytorchbot merge

@pytorchmergebot
Collaborator

Merge started

Your change will be merged once all checks pass (ETA 0-4 Hours).

Learn more about merging in the wiki.

Questions? Feedback? Please reach out to the PyTorch DevX Team

Advanced Debugging: check the merge workflow status.

pytorchmergebot pushed a commit that referenced this pull request Oct 9, 2024
## Overview
This PR adds a `shard_placement_fn: Optional[Callable[[nn.Parameter], Optional[Shard]]]` arg to `fully_shard` that allows users to specify FSDP sharding on a nonzero tensor dim. If doing so, then the tensor dim size must be divisible by the FSDP shard world size (a validation sketch follows the example below).

```
# Example:
def shard_placement_fn(param: nn.Parameter) -> Optional[Shard]:
    largest_dim = largest_dim_size = -1
    for dim, dim_size in enumerate(param.shape):
        if dim_size > largest_dim_size:
            largest_dim = dim
            largest_dim_size = dim_size
    return Shard(largest_dim)

fully_shard(module, shard_placement_fn=shard_placement_fn)
```
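For illustration, a hedged sketch of the divisibility requirement mentioned above; `check_placement` is a hypothetical helper, not part of the `fully_shard` API, and the import path follows the PyTorch 2.5 layout.

```
# Hedged sketch of the divisibility requirement stated above; `check_placement`
# is a hypothetical helper, not part of the fully_shard API.
import torch.nn as nn
from torch.distributed.tensor import Shard


def check_placement(param: nn.Parameter, placement: Shard, shard_world_size: int) -> None:
    dim_size = param.shape[placement.dim]
    if dim_size % shard_world_size != 0:
        raise ValueError(
            f"dim {placement.dim} of size {dim_size} is not divisible by the "
            f"FSDP shard world size {shard_world_size}"
        )
```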

## Follow-Ups
- **Copy kernels:** For all-gather copy-out, we currently copy out to temporaries and then chunk-dim-0 -> cat-shard-dim, incurring an extra copy for parameters sharded on a nonzero tensor dim. Similarly, for reduce-scatter copy-in, we currently chunk-shard-dim -> cat-dim-0, incurring an extra copy for gradients sharded on a nonzero tensor dim. @yifuwang has ideas for adding additional split size args to the copy ops that allow fusing these extra copies into the existing all-gather copy-out and reduce-scatter copy-in.

Pull Request resolved: #137496
Approved by: https://github.com/weifengpy
ghstack dependencies: #137593
@github-actions github-actions bot deleted the gh/awgu/650/head branch November 9, 2024 02:02
