KEMBAR78
[FSDP2] Added `shard_placement_fn` arg by awgu · Pull Request #137496 · pytorch/pytorch · GitHub
Skip to content

Conversation

@awgu
Copy link
Collaborator

@awgu awgu commented Oct 8, 2024

Stack from ghstack (oldest at bottom):

Overview

This PR adds a shard_placement_fn: Optional[Callable[[nn.Parameter], Optional[Shard]] arg to fully_shard that allows users to specify FSDP sharding on a nonzero tensor dim. If doing so, then the tensor dim size must be divisible by the FSDP shard world size.

# Example:
def shard_placement_fn(param: nn.Parameter) -> Optional[Shard]:
    largest_dim = largest_dim_size = -1
    for dim, dim_size in enumerate(param.shape):
        if dim_size > largest_dim_size:
            largest_dim = dim
            largest_dim_size = dim_size
    return Shard(largest_dim)

fully_shard(module, shard_placement_fn=shard_placement_fn)

Follow-Ups

  • Copy kernels: For all-gather copy-out, we currently copy-out to temporaries and then chunk-dim-0 -> cat-shard-dim, incurring an extra copy for parameters sharded on nonzero tensor dim. Similarly, for reduce-scatter copy-in, we currently chunk-shard-dim -> cat-dim-0, incurring an extra copy for gradients sharded on nonzero tensor dim. @yifuwang has ideas for adding additional split size args to the copy ops that allows fusing these extra copies into the existing all-gather copy-out and reduce-scatter copy-in.

cc @XilunWu @H-Huang @kwen2501 @wanchaol @fegin @fduwjj @wz337 @wconstab @d4l3k @c-p-i-o

@pytorch-bot
Copy link

pytorch-bot bot commented Oct 8, 2024

🔗 Helpful Links

🧪 See artifacts and rendered test results at hud.pytorch.org/pr/137496

Note: Links to docs will display an error until the docs builds have been completed.

✅ No Failures

As of commit 8fb9a36 with merge base d1b87e2 (image):
💚 Looks good so far! There are no failures yet. 💚

This comment was automatically generated by Dr. CI and updates every 15 minutes.

@pytorch-bot pytorch-bot bot added oncall: distributed Add this issue/PR to distributed oncall triage queue release notes: distributed (fsdp) release notes category labels Oct 8, 2024
awgu pushed a commit that referenced this pull request Oct 8, 2024
ghstack-source-id: 71ac63d
Pull Request resolved: #137496
@awgu awgu added release notes: distributed (fsdp2) release notes category and removed release notes: distributed (fsdp) release notes category labels Oct 8, 2024
@awgu awgu requested a review from weifengpy October 8, 2024 19:13
@awgu awgu marked this pull request as ready for review October 8, 2024 19:13
@awgu awgu requested a review from yifuwang October 8, 2024 19:13
Copy link
Contributor

@weifengpy weifengpy left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

epic work in such a short time! will sync with @yifuwang on copy kernels

@awgu awgu added the ciflow/trunk Trigger trunk jobs on your pull request label Oct 9, 2024
## Overview
This PR adds a `shard_placement_fn: Optional[Callable[[nn.Parameter], Optional[Shard]]` arg to `fully_shard` that allows users to specify FSDP sharding on a nonzero tensor dim. If doing so, then the tensor dim size must be divisible by the FSDP shard world size.

```
# Example:
def shard_placement_fn(param: nn.Parameter) -> Optional[Shard]:
    largest_dim = largest_dim_size = -1
    for dim, dim_size in enumerate(param.shape):
        if dim_size > largest_dim_size:
            largest_dim = dim
            largest_dim_size = dim_size
    return Shard(largest_dim)

fully_shard(module, shard_placement_fn=shard_placement_fn)
```


## Follow-Ups
- **Copy kernels:** For all-gather copy-out, we currently copy-out to temporaries and then chunk-dim-0 -> cat-shard-dim, incurring an extra copy for parameters sharded on nonzero tensor dim. Similarly, for reduce-scatter copy-in, we currently chunk-shard-dim -> cat-dim-0, incurring an extra copy for gradients sharded on nonzero tensor dim. yifuwang  has ideas for adding additional split size args to the copy ops that allows fusing these extra copies into the existing all-gather copy-out and reduce-scatter copy-in.

cc XilunWu H-Huang kwen2501 wanchaol fegin fduwjj wz337 wconstab d4l3k c-p-i-o

[ghstack-poisoned]
## Overview
This PR adds a `shard_placement_fn: Optional[Callable[[nn.Parameter], Optional[Shard]]` arg to `fully_shard` that allows users to specify FSDP sharding on a nonzero tensor dim. If doing so, then the tensor dim size must be divisible by the FSDP shard world size.

```
# Example:
def shard_placement_fn(param: nn.Parameter) -> Optional[Shard]:
    largest_dim = largest_dim_size = -1
    for dim, dim_size in enumerate(param.shape):
        if dim_size > largest_dim_size:
            largest_dim = dim
            largest_dim_size = dim_size
    return Shard(largest_dim)

fully_shard(module, shard_placement_fn=shard_placement_fn)
```


## Follow-Ups
- **Copy kernels:** For all-gather copy-out, we currently copy-out to temporaries and then chunk-dim-0 -> cat-shard-dim, incurring an extra copy for parameters sharded on nonzero tensor dim. Similarly, for reduce-scatter copy-in, we currently chunk-shard-dim -> cat-dim-0, incurring an extra copy for gradients sharded on nonzero tensor dim. yifuwang  has ideas for adding additional split size args to the copy ops that allows fusing these extra copies into the existing all-gather copy-out and reduce-scatter copy-in.

cc XilunWu H-Huang kwen2501 wanchaol fegin fduwjj wz337 wconstab d4l3k c-p-i-o

[ghstack-poisoned]
awgu pushed a commit that referenced this pull request Oct 9, 2024
ghstack-source-id: 7d4c314
Pull Request resolved: #137496
@awgu
Copy link
Collaborator Author

awgu commented Oct 9, 2024

@pytorchbot merge

@pytorchmergebot
Copy link
Collaborator

Merge started

Your change will be merged once all checks pass (ETA 0-4 Hours).

Learn more about merging in the wiki.

Questions? Feedback? Please reach out to the PyTorch DevX Team

Advanced Debugging
Check the merge workflow status
here

@github-actions github-actions bot deleted the gh/awgu/649/head branch November 9, 2024 02:04
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment

Labels

ciflow/trunk Trigger trunk jobs on your pull request Merged oncall: distributed Add this issue/PR to distributed oncall triage queue release notes: distributed (fsdp2) release notes category

Projects

None yet

Development

Successfully merging this pull request may close these issues.

3 participants