KEMBAR78
[c10d] Support optional backend if device_id provided by kwen2501 · Pull Request #140963 · pytorch/pytorch · GitHub
Skip to content

Conversation

@kwen2501
Copy link
Contributor

@kwen2501 kwen2501 commented Nov 18, 2024

Stack from ghstack (oldest at bottom):

Citing @malfet's comment in #136343

It would be great, if users do not have to modify their programs for every new backend, but rather use with torch.device('xpu'): and keep rest of the code unchanged.

This PR makes the backend specification ("nccl", "gloo") optional when user provides a devce_id to init_process_group (the acceptance of device_id has been previously supported for the purpose of eager init).

New user experience:

device = torch.device(device_type, rank % device_count)
dist.init_process_group(device_id=device)

The line of device = torch.device(...) is anyway needed because user would use it for tensor creation etc.

cc @H-Huang @awgu @wanchaol @fegin @fduwjj @wz337 @wconstab @d4l3k @c-p-i-o @zhangxiaoli73 @Chao1Han @jgong5

cc @H-Huang @awgu @wanchaol @fegin @fduwjj @wz337 @wconstab @d4l3k @c-p-i-o

@pytorch-bot
Copy link

pytorch-bot bot commented Nov 18, 2024

🔗 Helpful Links

🧪 See artifacts and rendered test results at hud.pytorch.org/pr/140963

Note: Links to docs will display an error until the docs builds have been completed.

❗ 1 Active SEVs

There are 1 currently active SEVs. If your PR is affected, please view them below:

✅ No Failures

As of commit af71423 with merge base ffb9790 (image):
💚 Looks good so far! There are no failures yet. 💚

This comment was automatically generated by Dr. CI and updates every 15 minutes.

@pytorch-bot pytorch-bot bot added oncall: distributed Add this issue/PR to distributed oncall triage queue release notes: distributed (c10d) release notes category labels Nov 18, 2024
kwen2501 added a commit that referenced this pull request Nov 18, 2024
ghstack-source-id: a4a9ea9
Pull Request resolved: #140963
@kwen2501 kwen2501 added the ciflow/trunk Trigger trunk jobs on your pull request label Nov 18, 2024
@kwen2501 kwen2501 requested a review from wz337 November 18, 2024 18:42
# >>> init_process_group()
# we set it to `undefined` and rely on lazy init.
if backend is None:
backend = "undefined"
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

what does backend = "undefined" do? Is it going to throw an error inside Backend below? or it somehow finds a default one later?

Copy link
Contributor Author

@kwen2501 kwen2501 Nov 18, 2024

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

It later get translated into "cuda:nccl,cpu:gloo" IIRC.
I guess I can make it go away? This PR just didn't touch that aspect.


backend_list = [UNDEFINED, GLOO, NCCL, UCC, MPI]

# 3rd-party devices can register the default backend support here
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Do you expect users to modify this dict directly or we will also add register/unregister API for third-party devices and backends? There are multiple dicts here and I guess register/unregister APIs can make them consistent without users' awareness and also more clear.

Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Also need to consider to support privateusr1 device too. cc @shink

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

@jgong5 Thanks for the ping! A register/unregister API will benefit all privateuse1 backends. I support this point.

cc: @FFFrog

Copy link
Contributor Author

@kwen2501 kwen2501 Nov 19, 2024

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Yeah, good idea. We can provide registration APIs, which would ideally be called when the a third-party module is imported so that no user involvement is needed. Let me add it in a next PR.

@wconstab
Copy link
Contributor

Lgtm

@kwen2501
Copy link
Contributor Author

@pytorchbot merge

@pytorchmergebot
Copy link
Collaborator

Merge started

Your change will be merged once all checks pass (ETA 0-4 Hours).

Learn more about merging in the wiki.

Questions? Feedback? Please reach out to the PyTorch DevX Team

Advanced Debugging
Check the merge workflow status
here

pobin6 pushed a commit to pobin6/pytorch that referenced this pull request Dec 5, 2024
Citing @malfet's [comment](pytorch#136343 (review)) in pytorch#136343
> It would be great, if users do not have to modify their programs for every new backend, but rather use with torch.device('xpu'): and keep rest of the code unchanged.

This PR makes the backend specification ("nccl", "gloo") optional when user provides a `devce_id` to `init_process_group` (the acceptance of `device_id` has been previously supported for the purpose of eager init).

New user experience:
```
device = torch.device(device_type, rank % device_count)
dist.init_process_group(device_id=device)
```

The line of `device = torch.device(...)` is anyway needed because user would use it for tensor creation etc.

Pull Request resolved: pytorch#140963
Approved by: https://github.com/wconstab
@github-actions github-actions bot deleted the gh/kwen2501/95/head branch December 20, 2024 02:05
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment

Labels

ciflow/trunk Trigger trunk jobs on your pull request Merged oncall: distributed Add this issue/PR to distributed oncall triage queue release notes: distributed (c10d) release notes category

Projects

None yet

Development

Successfully merging this pull request may close these issues.

5 participants