Use device-agnostic runtime API in distributed DDP/FSDP instead of `cuda` device specific. by zhangxiaoli73 · Pull Request #137678 · pytorch/pytorch · GitHub

Conversation

@zhangxiaoli73
Contributor

@zhangxiaoli73 zhangxiaoli73 commented Oct 10, 2024

Motivation

This PR aims to use device-agnostic runtime APIs in distributed DDP/FSDP instead of CUDA-specific ones.
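
For context, a minimal sketch of the direction this PR takes (not the actual diff): runtime calls hard-coded to torch.cuda are replaced with lookups through a device-agnostic entry point such as torch.get_device_module.

import torch

# Before: CUDA-specific runtime calls.
#   torch.cuda.synchronize()
#   torch.cuda.current_stream()

# After: resolve the module for whatever accelerator is active
# (cuda, xpu, ...), falling back to the CPU module if none is present.
device_module = torch.get_device_module()
device_module.synchronize()
stream = device_module.current_stream()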

cc @jgong5 @gujinghui @EikanWang @fengyuan14 @guangyey

cc @XilunWu @H-Huang @awgu @kwen2501 @wanchaol @fegin @fduwjj @wz337 @wconstab @d4l3k @c-p-i-o

@pytorch-bot

pytorch-bot bot commented Oct 10, 2024

🔗 Helpful Links

🧪 See artifacts and rendered test results at hud.pytorch.org/pr/137678

Note: Links to docs will display an error until the docs builds have been completed.

✅ You can merge normally! (1 Unrelated Failure)

As of commit df87094 with merge base 034b105:

FLAKY - The following job failed but was likely due to flakiness present on trunk:

This comment was automatically generated by Dr. CI and updates every 15 minutes.

@pytorch-bot pytorch-bot bot added oncall: distributed Add this issue/PR to distributed oncall triage queue release notes: distributed (ddp) release notes category labels Oct 10, 2024
@linux-foundation-easycla

linux-foundation-easycla bot commented Oct 10, 2024

CLA Signed

The committers listed above are authorized under a signed CLA.

torch/_utils.py Outdated

@functools.lru_cache(2)
def _get_device_module(device_type: str):
def _get_device_module(device_type: str = None):
Collaborator


Suggested change
def _get_device_module(device_type: str = None):
def _get_device_module(device_type: Optional[str] = None):

Implicit optionals should not be used.

Contributor Author


Thanks, got it. Change done.

# Make sure the allreduce will not conflict with any other ongoing process group.
if torch.cuda.is_available():
    torch.cuda.synchronize()
elif torch.xpu.is_available():
Collaborator


nit: what about the case where CUDA and XPU devices are both available, but only one is in use? Shouldn't this be based on the parameters and not whether the backend is available? Or based on the distributed group in some way?

Contributor Author


Currently, only one accelerator can be available at once on a given host (see https://github.com/pytorch/pytorch/blob/main/docs/source/torch.rst#accelerators). We can still make it more generic in this case.

I used the device type of flat_params (passed to allreduce later) to query the corresponding device module and perform synchronize() on CUDA and XPU.

If you need any further adjustments, let me know!

Suggested change
elif torch.xpu.is_available():
params_device_type = flat_params.device.type
if params_device_type in ["cuda", "xpu"]:
    _get_device_module(params_device_type).synchronize()

@soulitzer soulitzer requested a review from awgu October 10, 2024 14:46
@soulitzer soulitzer added the triaged This issue has been looked at by a team member, and triaged and prioritized into an appropriate module label Oct 10, 2024
@guangyey guangyey added the ciflow/xpu Run XPU CI tasks label Oct 11, 2024
@zhangxiaoli73
Contributor Author

@awgu @Skylion007 Thanks for your review comments; all of them should be addressed now. Could you please review again?

@zhangxiaoli73 zhangxiaoli73 force-pushed the cherry/distributed-frontend branch from cb096fd to 0975698 Compare November 1, 2024 08:56
ret_fut = torch.futures.Future()
stream = hook_state.upcast_stream
with torch.cuda.stream(stream):
with _get_device_module().stream(stream):
Contributor Author


There is no corresponding API in torch.acc right now, so I call _get_device_module() to get the device module.

Collaborator


Why don't we have torch.acc.stream?

Collaborator


As discussed in #132204 (comment), we will provide the support of with statement for torch.Stream as a context manager.

Collaborator


OK, but I guess if with stream is supported, we wouldn't need to change _get_device_module and make all the related changes in this PR, which sounds simpler? Is the with stream change more complex?

Collaborator


Yes, with stream: is simpler than with _get_device_module().stream(stream)

Collaborator


So, does it make sense to support with stream to avoid complicating this PR?

Contributor Author


@jgong5 It doesn't seem easy to support with stream in this PR. @guangyey has a WIP PR (#140138) that provides with stream by calling some new accelerator APIs.
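
For illustration, a minimal sketch of the two styles being discussed, assuming CUDA is the active accelerator; stream stands in for hook_state.upcast_stream from the hook code.

import torch

stream = torch.cuda.Stream()

# Style used in this PR: look up the backend module, then use its
# stream() context manager (torch.cuda.stream, torch.xpu.stream, ...).
with torch.get_device_module().stream(stream):
    pass  # work launched here is enqueued on `stream`

# Style discussed in #132204 / #140138: once torch.Stream supports the
# with statement directly, the module lookup is no longer needed:
#
# with stream:
#     ...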

Contributor


This is really nice discussion, thanks folks.

If we were to use get_device_module, I think it would be safer if we provide a device argument to it.
Previous:
torch.cuda <-- the device module is explicit
Current:
get_device_module() <-- assumes the user has set a "current device" context

Since the user may not have done so in their program, I was just a little cautious about whether this would cause a BC break.

Contributor


Oh nvm. get_device_module() gives priority to the accelerator over CPU. Since that's guaranteed, the current code is safe :)
Source of torch.get_device_module:

    elif device is None:
        # Using default accelerator type. If no accelerator is available, it automatically returns CPU device.

@zhangxiaoli73
Contributor Author

I have refined this PR to use torch.acc, which offers device-agnostic runtime APIs. @awgu @Skylion007 Could you please review again?

@zhangxiaoli73 zhangxiaoli73 changed the title Use detected device module in distributed DDP/FSDP instead of cuda device specific. Use device-agnostic runtime API in distributed DDP/FSDP instead of cuda device specific. Nov 1, 2024
Contributor

@kwen2501 kwen2501 left a comment


I like the PR. Thanks for the contribution. Had some minor questions.
Can you please sign the CLA? Thanks.

torch/_utils.py Outdated
Comment on lines 926 to 929
if device_type is None:
    device_type = torch._C._get_accelerator().type
Contributor

@kwen2501 kwen2501 Nov 2, 2024


Some comment would be appreciated here.
cc @albanD @janeyx99 to review this change.
Also, considering the significance of this change itself, does it make sense to put this change into a separate, base PR?

Contributor Author

@zhangxiaoli73 zhangxiaoli73 Nov 4, 2024


Comments added. Let me know if you want this split into a separate PR.

Contributor


I see that the needed functionality is now supported by:

torch.get_device_module(device=None)

Returns the module associated with a given device (e.g., torch.device('cuda'), "mtia:0", "xpu", ...). If no device is given, return the module for the current accelerator or CPU if none is present.

https://pytorch.org/docs/stable/generated/torch.get_device_module.html

Maybe use that API?

Contributor Author


Agreed. Removed the change in _get_device_module(...) and switched to torch.get_device_module(device=None).
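
A minimal usage sketch of that API, following the documented behavior quoted above (the flat_params name is just the tensor from the surrounding FSDP code):

import torch

# With device=None, torch picks the module for the current accelerator,
# or torch.cpu if no accelerator is present.
mod = torch.get_device_module(device=None)
mod.synchronize()

# An explicit device type also works, e.g. derived from a tensor:
# torch.get_device_module(flat_params.device.type).synchronize()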

# backward and set all DDP managed grads to None.
def wait_for_optim_stream_callback():
    torch.cuda.current_stream().wait_stream(optim_stream_state.optim_stream)
    torch.acc.current_stream().wait_stream(optim_stream_state.optim_stream)
Contributor


Curious: why do we sometimes use _get_device_module().stream and sometimes torch.acc.current_stream()?

Contributor Author


Difference in API semantics: torch.cuda.current_stream() returns a torch.Stream, while torch.cuda.stream(...) returns a StreamContext. https://github.com/pytorch/pytorch/blob/main/torch/cuda/__init__.py#L600

Currently, torch.acc doesn't provide a StreamContext API, so I have to get the device module and then obtain the stream context via _get_device_module().stream.
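
A small illustration of that semantic difference (assuming CUDA is available):

import torch

s = torch.cuda.Stream()

# torch.cuda.current_stream() returns a Stream object you can wait on:
torch.cuda.current_stream().wait_stream(s)

# torch.cuda.stream(s) returns a StreamContext that makes `s` the current
# stream for the ops launched inside the with block:
with torch.cuda.stream(s):
    pass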

Collaborator


I asked in another hunk: why don't we have torch.acc.stream?

Contributor Author


@jgong5 torch.acc.stream is not ready right now and will be supported by @guangyey

Collaborator

@guangyey guangyey Nov 5, 2024


As discussed in #132204 (comment), we will provide the support of with statement for torch.Stream as a context manager.

if torch.cuda.is_available():
    torch.cuda.synchronize()
params_device_type = flat_params.device.type
if params_device_type in ["cuda", "xpu"]:
Contributor


Is there a device-agnostic way?

Contributor Author


synchronize should be common to every accelerator (at least with no functional impact), so I changed it to:

if torch.acc.is_available():
    torch.acc.synchronize()

# enqueue a callback to wait for this stream at end of backward
def wait_for_stream_cb():
    torch.cuda.current_stream().wait_stream(stream)
    torch.acc.current_stream().wait_stream(stream)
Collaborator


Suggested change
torch.acc.current_stream().wait_stream(stream)
torch.accelerator.current_stream().wait_stream(stream)
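
For reference, a hedged sketch of how the callback might look with the torch.accelerator namespace suggested here, assuming an accelerator is present (stream is the side stream from the surrounding code):

import torch

def wait_for_stream_cb(stream: torch.Stream) -> None:
    # Make the active accelerator's current stream (CUDA, XPU, ...) wait
    # for work queued on `stream`, without hard-coding torch.cuda.
    if torch.accelerator.is_available():
        torch.accelerator.current_stream().wait_stream(stream)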

Contributor

@kwen2501 kwen2501 left a comment


I think we are getting close to landing. Just two (edit: one) questions left around get_device_module.


Contributor

@kwen2501 kwen2501 left a comment


LGTM. Maybe consider using torch.get_device_module(...)? It is more formal.

Collaborator

@guangyey guangyey left a comment


Please fix the lint error.

@guangyey
Collaborator

@pytorchbot rebase

@pytorchmergebot
Collaborator

@pytorchbot started a rebase job onto refs/remotes/origin/viable/strict. Check the current status here

@pytorchmergebot
Collaborator

Successfully rebased cherry/distributed-frontend onto refs/remotes/origin/viable/strict, please pull locally before adding more changes (for example, via git checkout cherry/distributed-frontend && git pull --rebase)

@pytorchmergebot pytorchmergebot force-pushed the cherry/distributed-frontend branch from 2268767 to bafe615 Compare November 11, 2024 01:54
@zhangxiaoli73
Contributor Author

LGTM. Maybe consider using torch.get_device_module(...)? It is more formal.

Thanks. torch.get_device_module(...) looks more formal. Let me use this API.

import torch.distributed as dist
from torch.autograd import Variable
from torch.distributed.utils import _free_storage
from torch._utils import _get_device_module


this import is no longer needed

import torch.distributed as dist
from torch.autograd import Variable

from torch._utils import _get_device_module


this import is no longer needed.

@guangyey guangyey added the ciflow/trunk Trigger trunk jobs on your pull request label Nov 12, 2024
@zhangxiaoli73
Contributor Author

@pytorchbot rebase

@pytorch-bot

pytorch-bot bot commented Nov 13, 2024

You don't have permissions to rebase this PR since you are a first time contributor. If you think this is a mistake, please contact PyTorch Dev Infra.

@guangyey
Collaborator

@pytorchbot rebase

@pytorchmergebot
Collaborator

@pytorchbot started a rebase job onto refs/remotes/origin/viable/strict. Check the current status here

@pytorchmergebot
Collaborator

Successfully rebased cherry/distributed-frontend onto refs/remotes/origin/viable/strict, please pull locally before adding more changes (for example, via git checkout cherry/distributed-frontend && git pull --rebase)

@pytorchmergebot pytorchmergebot force-pushed the cherry/distributed-frontend branch from 03559d3 to df87094 Compare November 13, 2024 01:58
@guangyey
Collaborator

"Unrelated failures"
@pytorchbot merge

@pytorchmergebot
Collaborator

Merge started

Your change will be merged once all checks pass (ETA 0-4 Hours).

Learn more about merging in the wiki.

Questions? Feedback? Please reach out to the PyTorch DevX Team

Advanced Debugging
Check the merge workflow status here

pobin6 pushed a commit to pobin6/pytorch that referenced this pull request Dec 5, 2024
pytorchmergebot pushed a commit that referenced this pull request Jan 17, 2025
# Motivation
In #137678, we used device-agnostic APIs to generalize the distributed module. As this [comment](#137678 (comment)) said, we will use the with statement of `torch.Stream` once #140138 lands.

Pull Request resolved: #144951
Approved by: https://github.com/kwen2501, https://github.com/albanD

Labels

ciflow/trunk Trigger trunk jobs on your pull request ciflow/xpu Run XPU CI tasks Merged oncall: distributed Add this issue/PR to distributed oncall triage queue open source release notes: distributed (ddp) release notes category triaged This issue has been looked at by a team member, and triaged and prioritized into an appropriate module

Projects

Status: Done

Development

Successfully merging this pull request may close these issues.

9 participants