[DTensor] Used new placements for neg dim in `from_local` by awgu · Pull Request #114134 · pytorch/pytorch · GitHub
Conversation

pytorch-bot bot commented Nov 20, 2023

🔗 Helpful Links

🧪 See artifacts and rendered test results at hud.pytorch.org/pr/114134

Note: Links to docs will display an error until the docs builds have been completed.

✅ No Failures

As of commit 43a9a1f with merge base 140c54e:
💚 Looks good so far! There are no failures yet. 💚

This comment was automatically generated by Dr. CI and updates every 15 minutes.

placements = list(placements)
for idx, placement in enumerate(placements):
    # normalize shard dim to be positive
    if placement.is_shard():
Collaborator Author

@wanchaol Should we converge to using placement.is_shard() or to using isinstance(placement, Shard)? The former calls the latter but allows for passing a dim arg to further check against, and the latter avoids having to use cast(Shard, placement).

It seems like is_shard() is a higher level construct and should be preferred, but I wanted to check.

Collaborator

@wanchaol wanchaol Nov 20, 2023

That's exactly the trade-off you pointed out, lol. I would like to use the former uniformly, but mypy can't use it to narrow the type, so many redundant casts would be needed if we switched every call site to it.

I think we can use either of them, whichever is easier in context. Maybe we can do this in the meantime:

  • isinstance(placement, Shard) is preferred for a simple type check
  • is_shard(dim), where dim becomes non-optional, so that this API is only used as a util to check whether the placement is a shard on a certain tensor dim
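
The trade-off discussed above can be sketched in a few lines. The classes below are hypothetical stand-ins, not the real DTensor `Placement` types; they only illustrate why a type checker narrows after `isinstance` but needs a `cast` after a method that returns a plain `bool`.

```python
from dataclasses import dataclass
from typing import Optional, cast


class Placement:
    """Hypothetical stand-in for a placement base class (not the real DTensor class)."""

    def is_shard(self, dim: Optional[int] = None) -> bool:
        if not isinstance(self, Shard):
            return False
        return dim is None or self.dim == dim


@dataclass
class Shard(Placement):
    dim: int


@dataclass
class Replicate(Placement):
    pass


def shard_dim_via_isinstance(placement: Placement) -> Optional[int]:
    # isinstance() narrows `placement` to Shard for the type checker,
    # so `.dim` can be read without a cast.
    if isinstance(placement, Shard):
        return placement.dim
    return None


def shard_dim_via_is_shard(placement: Placement) -> Optional[int]:
    # is_shard() returns a plain bool, so the checker still sees Placement
    # and an explicit cast is needed before reading `.dim`.
    if placement.is_shard():
        return cast(Shard, placement).dim
    return None
```

Both helpers behave the same at runtime; the difference is only in what a static checker such as mypy can infer.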

        placement = cast(Shard, placement)
        if placement.dim < 0:
            placements[idx] = Shard(placement.dim + local_tensor.ndim)

Collaborator Author

The conversion of placements to tuple is below:

return _FromTorchTensor.apply(  # pyre-ignore[16]: autograd func
    local_tensor,
    device_mesh,
    tuple(placements),
    run_check,
    shape,
    stride,
)
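
Putting the diff fragments together, the flow is: copy the placements into a list, replace any `Shard` with a negative dim by a new `Shard` with a non-negative dim, then convert to a tuple. A minimal, torch-free sketch follows; `normalize_placements` and the stand-in classes are hypothetical names for illustration, and `ndim` plays the role of `local_tensor.ndim`.

```python
from dataclasses import dataclass
from typing import Sequence, Tuple


@dataclass(frozen=True)
class Placement:
    """Hypothetical stand-in base class."""


@dataclass(frozen=True)
class Shard(Placement):
    dim: int


@dataclass(frozen=True)
class Replicate(Placement):
    pass


def normalize_placements(
    placements: Sequence[Placement], ndim: int
) -> Tuple[Placement, ...]:
    # Work on a copy so the caller's sequence is left untouched.
    normalized = list(placements)
    for idx, placement in enumerate(normalized):
        if isinstance(placement, Shard) and placement.dim < 0:
            # Create a new placement instead of mutating the existing one.
            normalized[idx] = Shard(placement.dim + ndim)
    return tuple(normalized)
```

For a 3-D tensor, `normalize_placements([Shard(-1), Replicate()], ndim=3)` yields `(Shard(dim=2), Replicate())`. Creating new objects rather than mutating in place also works when placements are immutable.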

@awgu awgu marked this pull request as ready for review November 20, 2023 17:34
Collaborator

@wanchaol wanchaol left a comment

cool!

@awgu awgu added ciflow/trunk Trigger trunk jobs on your pull request release notes: distributed (dtensor) release notes category labels Nov 20, 2023
pytorchmergebot pushed a commit that referenced this pull request Nov 20, 2023
pytorchmergebot pushed a commit that referenced this pull request Nov 20, 2023
pytorchmergebot pushed a commit that referenced this pull request Nov 21, 2023
This is a replacement for #113922. I think we can still leave the check for negative shard dimension in `compute_local_shape_and_global_offset` and replace the normalization logic with an assert. This should provide us a stack trace to see which user-facing API did not normalize the dim as expected.
Pull Request resolved: #114141
Approved by: https://github.com/wanchaol
ghstack dependencies: #113919, #113924, #114134, #113925, #113930
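
As a rough illustration of replacing normalization with an assert, a check along the following lines would fail with a stack trace pointing at the caller that passed an unnormalized dim. The helper name and message are assumptions for this sketch, not the code in `compute_local_shape_and_global_offset`.

```python
from typing import Sequence


def assert_shard_dim_normalized(shard_dim: int, local_shape: Sequence[int]) -> None:
    """Hypothetical helper: fail loudly instead of silently normalizing."""
    ndim = len(local_shape)
    assert 0 <= shard_dim < ndim, (
        f"Expected a normalized shard dim in [0, {ndim}), got {shard_dim}; "
        "the calling API should have normalized it"
    )
```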
pytorchmergebot pushed a commit that referenced this pull request Nov 21, 2023
**Overview**
Generally, I think we can try to freeze as many of these classes used in DTensor sharding propagation as possible so that we can cache hashes. This PR targets hashing `DTensorSpec`, which turns out to be relatively expensive.

**Details**
It looks like `tensor_meta` is only updated in `_wrap_output_spec_tensor_meta`, which only runs if the propagation was not cached:
https://github.com/pytorch/pytorch/blob/ae94c7e491e22f58d3df66571c1a568e51d70acd/torch/distributed/_tensor/sharding_prop.py#L137
https://github.com/pytorch/pytorch/blob/ae94c7e491e22f58d3df66571c1a568e51d70acd/torch/distributed/_tensor/sharding_prop.py#L153
In that case, I think we can cache the hash for the `DTensorSpec` and only update it when one of the hashed attributes changes, which we only really expect to happen for `tensor_meta`.

To ensure correctness, we need all hashed attributes to be immutable.
- `DeviceMesh` caches its hash: https://github.com/pytorch/pytorch/blob/a9134fa99a8986adf478a12db2ea5729d24554db/torch/distributed/_device_mesh.py#L181
- This PR makes each `Placement` a frozen `dataclass`, making them immutable (relying on the fact that they do not have references to any mutable objects).
- `TensorMeta` is a `NamedTuple` of `torch.Size`, `Tuple[int, ...]`, and `torch.dtype`, so it is immutable: https://github.com/pytorch/pytorch/blob/9916d8a9eaaf2c05c131f2a2dbe9eabeeaa9dffc/torch/distributed/_tensor/placement_types.py#L369-L375
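
The caching idea can be sketched as follows, assuming a simplified stand-in for `DTensorSpec`: the hash is computed once from immutable attributes and recomputed only when `tensor_meta` is reassigned. The class `Spec`, the string mesh, and the string dtype are placeholders for illustration only.

```python
from dataclasses import dataclass
from typing import Any, NamedTuple, Optional, Tuple


class TensorMeta(NamedTuple):
    shape: Tuple[int, ...]
    stride: Tuple[int, ...]
    dtype: str  # stand-in for torch.dtype


@dataclass(frozen=True)
class Shard:
    # frozen=True makes instances immutable and hashable
    dim: int


class Spec:
    """Hypothetical stand-in for DTensorSpec; not the real class."""

    def __init__(
        self,
        mesh: Any,
        placements: Tuple[Any, ...],
        tensor_meta: Optional[TensorMeta] = None,
    ) -> None:
        self.mesh = mesh
        self.placements = placements
        self.tensor_meta = tensor_meta
        self._hash = self._compute_hash()

    def _compute_hash(self) -> int:
        return hash((self.mesh, self.placements, self.tensor_meta))

    def __setattr__(self, name: str, value: Any) -> None:
        super().__setattr__(name, value)
        # Refresh the cached hash only when the one attribute expected to
        # change after construction is reassigned.
        if name == "tensor_meta" and hasattr(self, "_hash"):
            super().__setattr__("_hash", self._compute_hash())

    def __hash__(self) -> int:
        return self._hash

    def __eq__(self, other: object) -> bool:
        if not isinstance(other, Spec):
            return NotImplemented
        return (self.mesh, self.placements, self.tensor_meta) == (
            other.mesh,
            other.placements,
            other.tensor_meta,
        )
```

The cache is only safe because every hashed attribute is itself immutable; a mutable placement could change without `__setattr__` on the spec ever being called.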

**Example**
For some simple small GPT model:
Before: 0.125 ms
<img width="509" alt="Screenshot 2023-11-16 at 10 08 05 PM" src="https://github.com/pytorch/pytorch/assets/31054793/10e59401-f635-431f-80b5-1b48df3a706e">

After: 0.048 ms
<img width="294" alt="Screenshot 2023-11-16 at 10 08 47 PM" src="https://github.com/pytorch/pytorch/assets/31054793/09a3b0b9-f68c-4afc-bca1-c29a4b01c2fb">

The overall Adam CPU step time decreases from 7.647 ms to 6.451 ms.
Pull Request resolved: #113915
Approved by: https://github.com/wanchaol
ghstack dependencies: #113919, #113924, #114134, #113925, #113930, #114141
pytorchmergebot pushed a commit that referenced this pull request Nov 21, 2023
This is a nit change to save one `isinstance` call for when `dim` is not `None` but the placement is not `Shard`.
Pull Request resolved: #114140
Approved by: https://github.com/Skylion007, https://github.com/wanchaol
ghstack dependencies: #113919, #113924, #114134, #113925, #113930, #114141, #113915
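
One way such a change could look is sketched below with hypothetical free functions and stand-in classes (the actual method bodies are not shown in this conversation). In the "before" shape, a non-`Shard` placement with a non-`None` `dim` goes through `isinstance` twice; storing the result removes the second call.

```python
from dataclasses import dataclass
from typing import Optional, cast


class Placement:
    """Hypothetical stand-in base class."""


@dataclass(frozen=True)
class Shard(Placement):
    dim: int


@dataclass(frozen=True)
class Replicate(Placement):
    pass


def is_shard_before(placement: Placement, dim: Optional[int] = None) -> bool:
    if dim is not None and isinstance(placement, Shard):
        return placement.dim == dim
    # Reached with a second isinstance call when dim is given
    # but the placement is not a Shard.
    return isinstance(placement, Shard)


def is_shard_after(placement: Placement, dim: Optional[int] = None) -> bool:
    is_shard_instance = isinstance(placement, Shard)
    if dim is not None and is_shard_instance:
        return cast(Shard, placement).dim == dim
    return is_shard_instance
```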
pytorchmergebot pushed a commit that referenced this pull request Nov 22, 2023
This is a forward fix for #113781.

We lazily compute the hash so that we do not try to compute the hash on `SymInt`s (for the stride) during Dynamo tracing.

Tested via:
```
python test/distributed/_tensor/test_dtensor_compile.py -k test_2d_fsdp_tp_ac_compile
```
Pull Request resolved: #114322
Approved by: https://github.com/wanchaol
ghstack dependencies: #113919, #113924, #114134, #113925, #113930, #114141, #113915, #114140
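
The lazy variant can be sketched like this, again with a hypothetical stand-in class rather than the real `DTensorSpec`: nothing is hashed at construction or on attribute assignment, and the cached value is simply invalidated whenever a hashed attribute changes.

```python
from typing import Any, Optional, Tuple


class LazySpec:
    """Hypothetical stand-in that defers hashing until hash() is called."""

    _HASHED_ATTRS = ("mesh", "placements", "stride")

    def __init__(self, mesh: Any, placements: Tuple[Any, ...], stride: Tuple[Any, ...]) -> None:
        self._hash: Optional[int] = None
        self.mesh = mesh
        self.placements = placements
        self.stride = stride

    def __setattr__(self, name: str, value: Any) -> None:
        super().__setattr__(name, value)
        # Invalidate instead of recomputing, so assignment never hashes.
        if name in self._HASHED_ATTRS:
            super().__setattr__("_hash", None)

    def __hash__(self) -> int:
        # Compute on first use and reuse until the next invalidation.
        if self._hash is None:
            self._hash = hash((self.mesh, self.placements, self.stride))
        return self._hash
```

With this shape, values that cannot or should not be hashed yet are only touched if something actually asks for the hash.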
@facebook-github-bot facebook-github-bot deleted the gh/awgu/462/head branch November 24, 2023 15:27

Labels

ciflow/trunk Trigger trunk jobs on your pull request Merged release notes: distributed (dtensor) release notes category
