API to retrieve default distributed backend from device by ankurneog · Pull Request #140536 · pytorch/pytorch

Conversation

@ankurneog

@ankurneog ankurneog commented Nov 13, 2024

Motivation

The distributed APIs rely on backend names for the creation of process groups.
To abstract these names away from PG creation, this PR adds an API to get the default distributed backend for a device.
Device code needs to register its device and backend via torch.distributed.Backend.register_backend, or update the map torch.distributed.Backend.default_device_backend_map["device"] = "distributed_backend", before using the API.

A usage example is added in the test file (which can also be used to check the abstracted APIs); see the sketch below.
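For illustration, a minimal sketch of the intended flow. The API name get_default_backend_for_device and its placement under torch.distributed are assumptions based on later references in this thread, and "hpu"/"hccl" are purely illustrative placeholders.

```python
# Minimal sketch, assuming the new API is exposed as
# torch.distributed.get_default_backend_for_device (name taken from later
# references in this thread); "hpu"/"hccl" are illustrative placeholders.
import torch.distributed as dist

# A device package records its default backend, either through
# Backend.register_backend(...) or by updating the map directly:
dist.Backend.default_device_backend_map["hpu"] = "hccl"

# Process-group creation code no longer needs a hard-coded backend string:
backend = dist.get_default_backend_for_device("hpu")
print(backend)  # -> "hccl"
```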

cc @H-Huang @awgu @kwen2501 @wanchaol @fegin @fduwjj @wz337 @wconstab @d4l3k @c-p-i-o

@pytorch-bot

pytorch-bot bot commented Nov 13, 2024

🔗 Helpful Links

🧪 See artifacts and rendered test results at hud.pytorch.org/pr/140536

Note: Links to docs will display an error until the docs builds have been completed.

❗ 1 Active SEV

There is 1 currently active SEV. If your PR is affected, please view it below:

✅ No Failures

As of commit 7741dfa with merge base 740d1eb:
💚 Looks good so far! There are no failures yet. 💚

This comment was automatically generated by Dr. CI and updates every 15 minutes.

@pytorch-bot pytorch-bot bot added the oncall: distributed (Add this issue/PR to distributed oncall triage queue) and release notes: distributed (c10d) (release notes category) labels Nov 13, 2024
@ankurneog
Author

@kwen2501: can you please help with the review? Thanks.

@ezyang ezyang requested a review from kwen2501 November 14, 2024 03:22
@ezyang ezyang added the triaged label (This issue has been looked at by a team member, and triaged and prioritized into an appropriate module) Nov 14, 2024
@ankurneog
Author

@pytorchbot rebase

@ankurneog ankurneog requested a review from guangyey November 18, 2024 04:05
@pytorchmergebot
Collaborator

@pytorchbot started a rebase job onto refs/remotes/origin/viable/strict. Check the current status here

@pytorchmergebot
Collaborator

Successfully rebased distributed_api onto refs/remotes/origin/viable/strict, please pull locally before adding more changes (for example, via git checkout distributed_api && git pull --rebase)

@ankurneog ankurneog changed the title API to retrieve default backend from device API to retrieve default distributed backend from device Nov 19, 2024
@ankurneog
Author

@kwen2501: can you please help with the review and approval?

@kwen2501
Contributor

On it, sorry

@kwen2501
Contributor

kwen2501 commented Nov 19, 2024

Overall looks good to me.
Relatedly, we also want to waive the need for backend specification when calling init_process_group. Here is a PR enabling that: #140963.

To follow up, we can think of a way to register the "hpu": "hccl" mapping with torch c10d. Wdyt?
To kick-start that, it would be nice if you could share 1) where ProcessGroupHCCL is packaged today, and 2) how your users import that package. Ideally, we want the registration to happen automatically during the import so that the user does not need to get involved.

This is the UX we want to get to:

device = torch.device("hpu", rank % device_count)
dist.init_process_group(device_id=device)
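A slightly fuller, hedged version of that sketch might look as follows; the RANK environment variable, the device-count lookup, and the backend-free init_process_group call (which relies on #140963) are illustrative assumptions rather than part of this PR.

```python
# Hedged sketch of the target UX; "hpu" stands in for any accelerator whose
# backend has been registered with c10d, and the rank/device-count plumbing
# is illustrative only.
import os
import torch
import torch.distributed as dist

rank = int(os.environ["RANK"])
device_count = torch.get_device_module("hpu").device_count()  # assumes the hpu module is registered

device = torch.device("hpu", rank % device_count)
dist.init_process_group(device_id=device)  # backend inferred from the device (per #140963)
```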

@kwen2501
Contributor

kwen2501 commented Nov 19, 2024

One easy way to do registration is to add an entry here:

    # 3rd-party devices can register the default backend support here
    default_device_backend_map: Dict[str, str] = {
        "cpu": GLOO,
        "cuda": NCCL,
    }

That is, in your package's __init__.py, you can add a line like this:

    torch.distributed.Backend.default_device_backend_map["hpu"] = "hccl"

We can later add a formal registration API too. wdyt?

Contributor

@kwen2501 kwen2501 left a comment

Overall LGTM. Please see the comments above on how this can be integrated with the new c10d capability.

@ankurneog ankurneog force-pushed the distributed_api branch 2 times, most recently from 3e02e2b to c4e7727 November 20, 2024 04:13
@ankurneog
Author

ankurneog commented Nov 20, 2024

Thanks @kwen2501 for your comment; your change in #140963 will be helpful.

Regarding your question on HPU registration, it is done by calling the register_backend API as follows:

    torch.distributed.Backend.register_backend("hccl", _create_process_group_hccl, devices=["hpu"], extended_api=True)

register_backend ensures that the hpu-to-hccl mapping is recorded:

    Backend.default_device_backend_map[device] = name.lower()

I have modified the code accordingly to get the backend string directly using:

    Backend.default_device_backend_map.get(device_str)

Let me know your views.
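Putting the two pieces together, a hedged sketch of the registration-plus-retrieval flow, with a hypothetical stub in place of the real ProcessGroupHCCL constructor and assuming the new API is exposed as torch.distributed.get_default_backend_for_device:

```python
import torch.distributed as dist

def _create_process_group_hccl(*args, **kwargs):
    # Placeholder: the real device package constructs and returns a
    # ProcessGroup here.
    raise NotImplementedError

# Registering the backend also records the "hpu" -> "hccl" default mapping.
dist.Backend.register_backend(
    "hccl", _create_process_group_hccl, devices=["hpu"], extended_api=True
)
assert dist.Backend.default_device_backend_map["hpu"] == "hccl"

# The new API reads that mapping back, so callers never spell out "hccl".
print(dist.get_default_backend_for_device("hpu"))  # -> "hccl"
```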

@ankurneog
Author

@kwen2501: can you please help with the approval? Thanks.

Contributor

@kwen2501 kwen2501 left a comment

LGTM.

@guangyey
Collaborator

@pytorchbot merge

@pytorchmergebot
Collaborator

Merge started

Your change will be merged once all checks pass (ETA 0-4 Hours).

Learn more about merging in the wiki.

Questions? Feedback? Please reach out to the PyTorch DevX Team

Advanced Debugging
Check the merge workflow status here

@pytorchmergebot
Collaborator

Merge failed

Reason: 1 mandatory check(s) failed. The first few are:

Dig deeper by viewing the failures on hud

Details for Dev Infra team (raised by workflow job)

Failing merge rule: Core Maintainers

@guangyey
Collaborator

@pytorchbot rebase

@pytorchmergebot
Collaborator

@pytorchbot started a rebase job onto refs/remotes/origin/viable/strict. Check the current status here

@pytorchmergebot
Collaborator

Successfully rebased distributed_api onto refs/remotes/origin/viable/strict, please pull locally before adding more changes (for example, via git checkout distributed_api && git pull --rebase)

@ankurneog
Author

@pytorchbot merge

@pytorchmergebot
Collaborator

Merge started

Your change will be merged once all checks pass (ETA 0-4 Hours).

Learn more about merging in the wiki.

Questions? Feedback? Please reach out to the PyTorch DevX Team

Advanced Debugging
Check the merge workflow status here

@ankurneog ankurneog deleted the distributed_api branch November 22, 2024 11:03
pobin6 pushed a commit to pobin6/pytorch that referenced this pull request Dec 5, 2024
# Motivation
The distributed APIs rely on backend names for the creation of process groups.
To abstract these names away from PG creation, an API is added to get the default distributed backend for a device.
Device code needs to register its device and backend via ```torch.distributed.Backend.register_backend```, or update the map ```torch.distributed.Backend.default_device_backend_map["device"] = "distributed_backend"```, before using the API.

A usage example is added in the test file (which can also be used to check the abstracted APIs).

Pull Request resolved: pytorch#140536
Approved by: https://github.com/kwen2501
pytorchmergebot pushed a commit that referenced this pull request Feb 5, 2025
In this series of PRs we intend to refactor the distributed test cases to make them completely device agnostic.

These changes will include the following approaches:

- Allowing for multiple device types using instantiate_device_type_test
- Replacing calls to the CUDA stream with torch.get_device_module(device) wherever it applies
- Skipping setup steps required when using MultiProcessTestCase with DistributedTestBase (#138216) wherever applicable
- Replacing explicit calls to a distributed backend (NCCL, HCCL, etc.) with get_default_backend_for_device (#140536), as sketched below

This should result in a significant improvement in usability for all devices.

Pull Request resolved: #145222
Approved by: https://github.com/kwen2501
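For context, a hedged sketch of the device-agnostic pattern this refactor targets; the helper names come from the references above (DistributedTestBase from #138216, get_default_backend_for_device from this PR), and the surrounding plumbing is illustrative only.

```python
# Sketch only: replaces hard-coded "nccl"/"hccl" strings and direct torch.cuda
# calls with device-agnostic lookups, as described in the bullet list above.
# Assumes MASTER_ADDR/MASTER_PORT are already set in the environment.
import torch
import torch.distributed as dist

def init_for_device(device_type: str, rank: int, world_size: int) -> None:
    # Look up the backend for this device instead of hard-coding it.
    backend = dist.get_default_backend_for_device(device_type)
    dist.init_process_group(backend=backend, rank=rank, world_size=world_size)

    # Use the generic device module instead of torch.cuda directly.
    device_module = torch.get_device_module(device_type)
    device_module.set_device(rank % device_module.device_count())
```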
mori360 pushed a commit to mori360/pytorch that referenced this pull request Feb 6, 2025
…ch#145222)

In this series of PRs we intend to refactor the distributed test cases to make them completely device agnostic.

These changes will include the following approaches:

- Allowing for multiple device types using instantiate_device_type_test
- Replacing calls to the CUDA stream with torch.get_device_module(device) wherever it applies
- Skipping setup steps required when using MultiProcessTestCase with DistributedTestBase (pytorch#138216) wherever applicable
- Replacing explicit calls to a distributed backend (NCCL, HCCL, etc.) with get_default_backend_for_device (pytorch#140536)

This should result in a significant improvement in usability for all devices.

Pull Request resolved: pytorch#145222
Approved by: https://github.com/kwen2501

Labels

- ciflow/trunk (Trigger trunk jobs on your pull request)
- Merged
- oncall: distributed (Add this issue/PR to distributed oncall triage queue)
- open source
- release notes: distributed (c10d) (release notes category)
- triaged (This issue has been looked at by a team member, and triaged and prioritized into an appropriate module)

6 participants