[dtensor] full_tensor to return synchronously #113322
Conversation
The full_tensor API should return synchronously instead of returning an AsyncCollectiveTensor; if the return value is an AsyncCollectiveTensor, we wait on it directly. This makes the full_tensor API more precise. [ghstack-poisoned]
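To make the intended behavior concrete, here is a minimal usage sketch. The 4-rank CPU mesh and the distribute_tensor setup are illustrative assumptions (a process group launched via torchrun is presumed), not part of this PR:

```python
import torch
import torch.distributed._functional_collectives as funcol
from torch.distributed._tensor import DeviceMesh, Shard, distribute_tensor

# Assumption: a 4-rank process group is already initialized (e.g. via torchrun).
mesh = DeviceMesh("cpu", list(range(4)))  # 1-D mesh over 4 ranks
dtensor = distribute_tensor(torch.randn(8, 8), mesh, [Shard(0)])

full = dtensor.full_tensor()
# With this PR, full_tensor waits on any pending collective before returning,
# so the result is a plain tensor rather than an AsyncCollectiveTensor wrapper.
assert not isinstance(full, funcol.AsyncCollectiveTensor)
```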
🔗 Helpful Links
🧪 See artifacts and rendered test results at hud.pytorch.org/pr/113322
Note: Links to docs will display an error until the docs builds have been completed.
✅ No Failures as of commit e098488 with merge base 84d64d7.
This comment was automatically generated by Dr. CI and updates every 15 minutes.
LGTM!
@pytorchbot merge
Merge started. Your change will be merged once all checks pass (ETA 0-4 hours). Learn more about merging in the wiki. Questions? Feedback? Please reach out to the PyTorch DevX Team.
The full_tensor API should return synchronously instead of returning an AsyncCollectiveTensor; if the return value is an AsyncCollectiveTensor, we wait on it directly. This makes the full_tensor API more precise.
Pull Request resolved: pytorch#113322
Approved by: https://github.com/wz337
```python
def forward(  # type: ignore[override]
    ctx,
    input: "DTensor",
    grad_placements: Optional[Sequence[Placement]],
    async_output: bool,
):
    ctx.dtensor_spec = input._spec
    ctx.grad_placements = grad_placements
    local_tensor = input._local_tensor
    if not async_output and isinstance(local_tensor, funcol.AsyncCollectiveTensor):
        # synchronously wait for any pending collectives to get the result tensor
        local_tensor = local_tensor.trigger_wait()
        local_tensor = local_tensor.elem  # type: ignore[attr-defined]
```
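Read on its own, the branch above amounts to a small "wait and unwrap" step. A hedged sketch of that step as a standalone helper, using only the trigger_wait/.elem API shown in the diff (the helper name wait_if_async is illustrative, not part of the PR):

```python
import torch
import torch.distributed._functional_collectives as funcol

def wait_if_async(t: torch.Tensor) -> torch.Tensor:
    """Illustrative helper (not in the PR): block on any pending collective
    and unwrap the result, mirroring the branch in the diff above."""
    if isinstance(t, funcol.AsyncCollectiveTensor):
        t = t.trigger_wait()  # synchronously wait for the pending collective
        t = t.elem            # unwrap to the underlying torch.Tensor
    return t
```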
I don't have a strong opinion on this, but it might bring some perf regression, because now we wait on every .to_local call. But if we think this is the defined or expected behavior, then that's fine.