[c10d] Remove deprecated multi-gpu-per-thread APIs by kwen2501 · Pull Request #114156 · pytorch/pytorch · GitHub

Conversation

@kwen2501
Contributor

@kwen2501 kwen2501 commented Nov 20, 2023

As of today, PyTorch Distributed's preferred programming model is one device per thread, as exemplified by the APIs in its documentation. The multi-GPU functions (which assumed multiple GPUs per CPU thread) have been deprecated for three releases. This PR removes them ahead of the 2.2 release.
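The one-device-per-thread model the PR description refers to can be sketched as follows. This is a minimal, hedged illustration: `all_reduce_multigpu` is used here as a representative of the removed multi-GPU functions, and the single-process `gloo` setup (with assumed `MASTER_ADDR`/`MASTER_PORT` values) exists only to make the snippet runnable on CPU; in real use each rank is its own process driving its own device.

```python
import os
import torch
import torch.distributed as dist

# Single-process CPU setup purely for illustration; real jobs launch one
# process per device (e.g. via torchrun) with rank/world_size set for each.
os.environ.setdefault("MASTER_ADDR", "127.0.0.1")
os.environ.setdefault("MASTER_PORT", "29501")
dist.init_process_group("gloo", rank=0, world_size=1)

# Preferred model: one tensor (one device) per process.
# The removed style passed a list of tensors, one per local GPU,
# e.g. all_reduce_multigpu([t0, t1, ...]).
t = torch.ones(4)
dist.all_reduce(t)  # in-place sum across ranks; with world_size == 1, a no-op
print(t.tolist())

dist.destroy_process_group()
```

With more than one rank, each process would contribute its own single tensor and receive the reduced result in place.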

cc @ezyang @gchanan

@kwen2501 kwen2501 requested a review from albanD as a code owner November 20, 2023 19:25
@pytorch-bot pytorch-bot bot added the release notes: distributed (c10d) release notes category label Nov 20, 2023
@pytorch-bot

pytorch-bot bot commented Nov 20, 2023

🔗 Helpful Links

🧪 See artifacts and rendered test results at hud.pytorch.org/pr/114156

Note: Links to docs will display an error until the docs builds have been completed.

✅ No Failures

As of commit ef380df with merge base 140c54e (image):
💚 Looks good so far! There are no failures yet. 💚

This comment was automatically generated by Dr. CI and updates every 15 minutes.

Collaborator

@albanD albanD left a comment


Public API doc change sounds good to me.

Member

@H-Huang H-Huang left a comment


Looks good. I think there is a lot of follow-up work that can be done (BE?).

We can probably comb through many of the util functions, find which ones assume multiple GPUs, and remove that logic to simplify our code. For example, in ProcessGroupNCCL, getDeviceList() only makes sense for multiple GPUs, and we have a lot of now-unnecessary logic that loops through devices.

seq_++;

// Currently, the API permits two scenarios where inputs.size() and
// Currently, the API permits one scenario where inputs.size() and
Member


Should we create a follow-up task to remove the vector arguments from all collectives (e.g. std::vector&lt;at::Tensor&gt;&amp; to at::Tensor&amp;)? The only reason the vector form was added was for the multi-GPU collectives, right?

Contributor Author


There are also the _coalesced Python APIs (the "one scenario" left that is referred to here). And yes, we'd need to think of a strategic way to remove them.

@kwen2501
Contributor Author

@H-Huang correct, there is a lot of Better Engineering work to do. This PR just removes the user-facing APIs, as a starting point. Next we could go through the backend implementation in ProcessGroupNCCL.cpp.

@kwen2501 kwen2501 added the suppress-bc-linter Suppresses the failures of API backward-compatibility linter (Lint/bc_linter) label Nov 20, 2023
@kwen2501
Contributor Author

Adding the suppress-bc-linter label because the BC break here is intentional.

nGPUs = torch.cuda.device_count()
visible_devices = range(nGPUs)

if backend == "nccl":
Contributor


Are we going to remove init_multigpu_helper completely?

Contributor Author


Likely not. It is used in a lot of places (even non-distributed tests), so I am leaving it in place; other tests may need it for other purposes.
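For context on what this helper does, the diff snippet above suggests it partitions the visible GPUs among the ranks. A minimal pure-Python sketch of that rank-to-device partitioning is below; the signature is hypothetical (the real helper derives the GPU count from torch.cuda.device_count() and takes the backend into account), so treat it as an illustration of the shape of the mapping, not the actual implementation.

```python
def init_multigpu_helper(world_size: int, n_gpus: int) -> dict:
    """Partition visible GPU ids evenly across ranks (hypothetical signature).

    Returns a mapping {rank: [gpu ids that rank may use]}.
    """
    visible_devices = list(range(n_gpus))
    per_rank = max(n_gpus // world_size, 1)
    return {
        rank: visible_devices[rank * per_rank:(rank + 1) * per_rank]
        for rank in range(world_size)
    }

print(init_multigpu_helper(world_size=2, n_gpus=4))  # {0: [0, 1], 1: [2, 3]}
```

Under the one-device-per-thread model, each rank would simply use the first (and typically only) entry of its list.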

Contributor

@fduwjj fduwjj left a comment


LGTM

@kwen2501
Contributor Author

@pytorchbot merge

@pytorch-bot pytorch-bot bot added the ciflow/trunk Trigger trunk jobs on your pull request label Nov 21, 2023
@pytorchmergebot
Collaborator

Merge started

Your change will be merged once all checks pass (ETA 0-4 Hours).

Learn more about merging in the wiki.

Questions? Feedback? Please reach out to the PyTorch DevX Team


@kit1980
Contributor

kit1980 commented Nov 30, 2023

This has caused multiple issues in the Meta-internal pipelines. In general, we should make sure important internal usages are updated before removal; this can be done with help from https://github.com/pytorch-labs/torchfix

@kit1980 kit1980 added module: bc-breaking Related to a BC-breaking change topic: bc breaking topic category topic: bc_breaking labels Nov 30, 2023
@github-actions github-actions bot deleted the remove_multigpu_apis branch February 19, 2024 01:59