Use device-agnostic runtime API in distributed DDP/FSDP instead of cuda device specific.
#137678
Conversation
🔗 Helpful Links
🧪 See artifacts and rendered test results at hud.pytorch.org/pr/137678
Note: Links to docs will display an error until the docs builds have been completed.
✅ You can merge normally! (1 Unrelated Failure) As of commit df87094 with merge base 034b105. FLAKY - The following job failed but was likely due to flakiness present on trunk.
This comment was automatically generated by Dr. CI and updates every 15 minutes.
torch/_utils.py (Outdated)

```diff
 @functools.lru_cache(2)
-def _get_device_module(device_type: str):
+def _get_device_module(device_type: str = None):
```
```diff
-def _get_device_module(device_type: str = None):
+def _get_device_module(device_type: Optional[str] = None):
```
Implicit optionals should not be used.
Thanks, got it. Change done.
```diff
 # Make sure the allreduce will not conflict with any other ongoing process group.
 if torch.cuda.is_available():
     torch.cuda.synchronize()
+elif torch.xpu.is_available():
```
nit: what about the case where CUDA and XPU devices are both available, but only one is in use? Shouldn't this be based on the parameters and not whether the backend is available? Or based on the distributed group in some way?
Currently, only one accelerator can be available at once on a given host (see https://github.com/pytorch/pytorch/blob/main/docs/source/torch.rst#accelerators). We can still make it more generic in this case.
I used the device type of flat_params (passed to allreduce later) to query the corresponding device module and perform synchronize() on cuda and xpu.
If you need any further adjustments, let me know!
```diff
-elif torch.xpu.is_available():
+params_device_type = flat_params.device.type
+if params_device_type in ["cuda", "xpu"]:
+    _get_device_module(params_device_type).synchronize()
```
@awgu @Skylion007 Thanks for your review comments; all should be addressed now. Could you please review again?
Force-pushed cb096fd to 0975698 (Compare)
```diff
 ret_fut = torch.futures.Future()
 stream = hook_state.upcast_stream
-with torch.cuda.stream(stream):
+with _get_device_module().stream(stream):
```
There is no corresponding API in torch.acc right now, so I call _get_device_module() to get the device module.
Why don't we have torch.acc.stream?
As discussed in #132204 (comment), we will provide the support of with statement for torch.Stream as a context manager.
OK, but if `with stream` were supported, we wouldn't need to change _get_device_module and make all the related changes in this PR, which sounds simpler. Is the `with stream` change more complex?
Yes, `with stream:` is simpler than `with _get_device_module().stream(stream):`.
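A hedged sketch of what that simpler form looks like once the with-statement support for `torch.Stream` (#140138) is in place; the `torch.accelerator` name and the availability guard are assumptions so the snippet also runs on hosts without accelerators:

```python
import torch

# Once torch.Stream works as a context manager, per-backend StreamContext
# lookups like _get_device_module().stream(stream) become unnecessary.
if hasattr(torch, "accelerator") and torch.accelerator.is_available():
    stream = torch.Stream()
    with stream:
        ...  # work enqueued here runs on `stream`
```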
So, does it make sense to support with stream to avoid complicating this PR?
This is really nice discussion, thanks folks.
If we were to use get_device_module, I think it would be safer if we provide a device argument to it.
Previous:
torch.cuda <-- the device module is explicit
Current:
get_device_module() <-- assumes the user has set a "current device" context
Since the user may not have done so in their program, I was just a little cautious about whether this would cause a BC break.
Oh nvm. get_device_module() would give priority to accelerator than cpu. When that's guaranteed, then the current code is safe :)
Source of torch.get_device_module:

```python
elif device is None:
    # Using default accelerator type. If no accelerator is available, it automatically returns CPU device.
```
I have refined this PR.
I like the PR. Thanks for the contribution. Had some minor questions.
Can you please sign the CLA? Thanks.
torch/_utils.py (Outdated)

```python
if device_type is None:
    device_type = torch._C._get_accelerator().type
```
Comments added. Let me know if you want this split into a separate PR.
I see that the needed functionality is now supported by torch.get_device_module(device=None):
Returns the module associated with a given device (e.g., torch.device('cuda'), "mtia:0", "xpu", ...). If no device is given, it returns the module for the current accelerator, or CPU if none is present.
https://pytorch.org/docs/stable/generated/torch.get_device_module.html
Maybe use that API?
Agreed. Removed the change in _get_device_module(...) and used torch.get_device_module(device=None).
```diff
 # backward and set all DDP managed grads to None.
 def wait_for_optim_stream_callback():
-    torch.cuda.current_stream().wait_stream(optim_stream_state.optim_stream)
+    torch.acc.current_stream().wait_stream(optim_stream_state.optim_stream)
```
Curious: why do we sometimes use _get_device_module().stream and sometimes torch.acc.current_stream()?
It's a difference in API semantics: torch.cuda.current_stream() returns a torch.Stream, while torch.cuda.stream(...) returns a StreamContext. https://github.com/pytorch/pytorch/blob/main/torch/cuda/__init__.py#L600
Currently torch.acc doesn't provide a StreamContext API, so I have to get the device module and then obtain the stream context via _get_device_module().stream.
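To make that semantic split concrete, a small sketch using CUDA as the example backend (guarded so it is a no-op on CPU-only hosts):

```python
import torch

if torch.cuda.is_available():
    s = torch.cuda.Stream()
    # torch.cuda.stream(...) returns a StreamContext; entering it makes `s`
    # the current stream for work launched inside the block.
    with torch.cuda.stream(s):
        # torch.cuda.current_stream() returns the Stream that is now active.
        assert torch.cuda.current_stream() == s
```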
I asked in another hunk, why don't we have torch.acc.stream?
As discussed in #132204 (comment), we will provide the support of with statement for torch.Stream as a context manager.
```diff
-if torch.cuda.is_available():
-    torch.cuda.synchronize()
+params_device_type = flat_params.device.type
+if params_device_type in ["cuda", "xpu"]:
```
Is there a device-agnostic way?
synchronize should be common to every accelerator (at least with no functional impact), so I changed it to:

```python
if torch.acc.is_available():
    torch.acc.synchronize()
```
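For reference, the same guard spelled with the `torch.accelerator` name suggested later in this review (the `hasattr` check is an assumption so the snippet degrades gracefully on builds without that module):

```python
import torch

# Device-agnostic synchronize: dispatches to whichever accelerator backend
# (cuda, xpu, ...) is active, and is skipped entirely on CPU-only hosts.
if hasattr(torch, "accelerator") and torch.accelerator.is_available():
    torch.accelerator.synchronize()
```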
Force-pushed 4584872 to 01b91c7 (Compare)
```diff
 # enqueue a callback to wait for this stream at end of backward
 def wait_for_stream_cb():
-    torch.cuda.current_stream().wait_stream(stream)
+    torch.acc.current_stream().wait_stream(stream)
```
```diff
-torch.acc.current_stream().wait_stream(stream)
+torch.accelerator.current_stream().wait_stream(stream)
```
I think we are getting close to land. Just two (edit: one) questions left around get_device_module
LGTM. Maybe consider using torch.get_device_module(...)? It is more formal.
Please fix lint error.
@pytorchbot rebase
@pytorchbot started a rebase job onto refs/remotes/origin/viable/strict. Check the current status here.
Successfully rebased.
Force-pushed 2268767 to bafe615 (Compare)
Thanks.
```diff
 import torch.distributed as dist
 from torch.autograd import Variable
 from torch.distributed.utils import _free_storage
+from torch._utils import _get_device_module
```
This import is no longer needed.
```diff
 import torch.distributed as dist
 from torch.autograd import Variable

+from torch._utils import _get_device_module
```
This import is no longer needed.
@pytorchbot rebase
You don't have permissions to rebase this PR since you are a first-time contributor. If you think this is a mistake, please contact PyTorch Dev Infra.
@pytorchbot rebase
@pytorchbot started a rebase job onto refs/remotes/origin/viable/strict. Check the current status here.
Successfully rebased.
Force-pushed 03559d3 to df87094 (Compare)
"Unrelated failures"
Merge started. Your change will be merged once all checks pass (ETA 0-4 Hours). Learn more about merging in the wiki. Questions? Feedback? Please reach out to the PyTorch DevX Team.
…uda` device specific. (pytorch#137678)

# Motivation
This PR targets to use device-agnostic runtime API in distributed DDP/FSDP instead of `cuda` device specific.

cc @jgong5 @gujinghui @EikanWang @fengyuan14 @guangyey

Pull Request resolved: pytorch#137678
Approved by: https://github.com/kwen2501, https://github.com/guangyey, https://github.com/jgong5
# Motivation
In #137678, we used the device-agnostic APIs to generalize the distributed module. As this [comment](#137678 (comment)) said, we will use the with statement of `torch.Stream` once #140138 is landed.

Pull Request resolved: #144951
Approved by: https://github.com/kwen2501, https://github.com/albanD
Motivation
This PR targets to use device-agnostic runtime API in distributed DDP/FSDP instead of `cuda` device specific.

cc @jgong5 @gujinghui @EikanWang @fengyuan14 @guangyey
cc @XilunWu @H-Huang @awgu @kwen2501 @wanchaol @fegin @fduwjj @wz337 @wconstab @d4l3k @c-p-i-o