[DeviceMesh] Update get_group and add get_all_groups #128097
Conversation
🔗 Helpful Links
🧪 See artifacts and rendered test results at hud.pytorch.org/pr/128097
Note: Links to docs will display an error until the docs builds have been completed.
❗ 1 Active SEV: There is 1 currently active SEV. If your PR is affected, please view it below.
❌ 1 New Failure, 4 Unrelated Failures as of commit a2640e5 with merge base 65aa16f:
NEW FAILURE - The following job has failed:
FLAKY - The following jobs failed but were likely due to flakiness present on trunk:
UNSTABLE - The following job failed but was likely due to flakiness present on trunk and has been marked as unstable:
This comment was automatically generated by Dr. CI and updates every 15 minutes.
@wconstab @wanchaol As suggested, I updated the PR. I believe most of the use cases I've seen use `get_group` with `mesh_dim` specified.
@wz337 has imported this pull request. If you are a Meta employee, you can view this diff on Phabricator.
```diff
-return not_none(
-    _find_pg_by_ranks_and_tag(*self._dim_group_infos[0][:2])
+if self.mesh.ndim > 1 and mesh_dim is None:
+    raise RuntimeError(
```
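For context, a minimal sketch of the control flow this hunk introduces; the error message wording below is an assumption for illustration, not quoted from the diff:

```python
def get_group(self, mesh_dim=None):
    # New behavior: on a multi-dimensional mesh, the caller must name a dim.
    if self.mesh.ndim > 1 and mesh_dim is None:
        # Message wording is illustrative, not the exact string from the PR.
        raise RuntimeError(
            f"DeviceMesh has {self.mesh.ndim} dimensions; `mesh_dim` must be "
            "specified when the mesh is not 1-dimensional."
        )
    # 1-D mesh (or explicit mesh_dim): resolve and return the single ProcessGroup.
    ...
```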
do we want to raise the error here, or just issue a deprecation warning?
- how many users do we think are already using this API and would hit this error?
cc @wanchaol
I have reviewed all the internal use cases that I can find by fbgs `mesh.get_group(` (unless people named it something else), and it looks like they are all either:
- already calling get_group with mesh_dim specified, or
- calling get_group on a 1D child mesh without mesh_dim specified.
For these two cases we are already returning a single PG anyway, so neither will be affected (see the sketch below).
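A hedged illustration of those two unaffected patterns; the mesh shape and dim names are made up, and a distributed process group is assumed to be initialized:

```python
from torch.distributed.device_mesh import init_device_mesh

# Hypothetical 2-D mesh; the "dp"/"tp" names and (2, 4) shape are illustrative.
mesh = init_device_mesh("cuda", (2, 4), mesh_dim_names=("dp", "tp"))

# Case 1: get_group with mesh_dim specified still returns a single ProcessGroup.
tp_group = mesh.get_group("tp")

# Case 2: a 1-D child mesh; get_group without mesh_dim also returns a single ProcessGroup.
tp_mesh = mesh["tp"]
tp_group_again = tp_mesh.get_group()
```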
ok. it may be fine. if you want to derisk further against a revert, you could do a warning in this PR and stack a PR on top that changes the warning to an error. but i'll stamp to unblock.
> ok. it may be fine. if you want to derisk further against a revert, you could do a warning in this PR and stack a PR on top that changes the warning to an error. but i'll stamp to unblock.
I think it should be fine. I am importing this PR as a diff to let internal tests run on it. Either way, even if we only issue a warning, any caller that still expects a list at this point would hit an error anyway.
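For reference, a minimal sketch of the warning-first alternative discussed above; this is hypothetical, since the PR as landed raises a RuntimeError directly:

```python
import warnings

def get_group(self, mesh_dim=None):
    if self.mesh.ndim > 1 and mesh_dim is None:
        # Hypothetical intermediate step: warn in this PR, then a stacked PR
        # upgrades this warning to a RuntimeError.
        warnings.warn(
            "Calling get_group() without mesh_dim on a multi-dimensional "
            "DeviceMesh is deprecated; pass mesh_dim or use get_all_groups().",
            FutureWarning,
        )
    ...
```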
yeah I think either a warning or an error works. Although this is a corner case, it's technically BC-breaking, so we would probably need to put a line in the release notes to explain the change.
torch/distributed/device_mesh.py
Outdated
""" | ||
Returns a list of ProcessGroups corresponding to the mesh dimensions, or | ||
returns a single ProcessGroup if mesh_dim is specified or the given mesh has | ||
Returns a single ProcessGroup if mesh_dim is specified or the given mesh has |
nit: rephrase as something like "Returns the single ProcessGroup specified by `mesh_dim`, or, if `mesh_dim` is unspecified and the DeviceMesh is 1-dimensional, returns the only ProcessGroup in the mesh."
```diff
 dim_groups = mesh.get_group()
-assert isinstance(dim_groups, list)
-return dim_groups[0]
+return dim_groups
```
nit: it's not `groups` anymore; maybe just say `return mesh.get_group()` instead?
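To make the resulting API split concrete, a hedged usage example; the mesh shape and dim names are illustrative, and torch.distributed is assumed to be initialized:

```python
from torch.distributed.device_mesh import init_device_mesh

# Hypothetical 2-D mesh for illustration.
mesh = init_device_mesh("cuda", (2, 4), mesh_dim_names=("dp", "tp"))

tp_group = mesh.get_group("tp")     # single ProcessGroup for one mesh dim
all_groups = mesh.get_all_groups()  # list of ProcessGroups, one per mesh dim
# mesh.get_group() with no mesh_dim now raises on this 2-D mesh.
```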
Force-pushed from 2bd9ffc to 7b0a7e9 (Compare)
@pytorchmergebot rebase
@pytorchbot started a rebase job onto refs/remotes/origin/viable/strict. Check the current status here.
Successfully rebased.
Force-pushed from 7b0a7e9 to 8993ebe (Compare)
@pytorchmergebot rebase
@wz337 has imported this pull request. If you are a Meta employee, you can view this diff on Phabricator.
lgtm! one nit inlined
@wz337 has imported this pull request. If you are a Meta employee, you can view this diff on Phabricator.
@pytorchmergebot merge
Merge started: Your change will be merged once all checks pass (ETA 0-4 hours). Learn more about merging in the wiki. Questions? Feedback? Please reach out to the PyTorch DevX Team.
Merge failed. Reason: 1 job has failed, first few of them are: trunk / win-vs2019-cpu-py3 / build. Details for Dev Infra team: raised by workflow job.
@pytorchmergebot merge -i "unreleated trunk errror"
❌ 🤖 pytorchbot command failed:
Try ...
@pytorchmergebot merge -i
Merge started: Your change will be merged while ignoring the following 5 checks: inductor-periodic / cuda12.4-py3.10-gcc9-sm86 / test (dynamic_inductor_torchbench, 2, 2, linux.g5.4xlarge.nvidia.gpu), inductor-periodic / cuda12.1-py3.10-gcc9-sm86-periodic-dynamo-benchmarks / test (dynamic_aot_eager_torchbench, 2, 2, linux.g5.4xlarge.nvidia.gpu), inductor / linux-jammy-cpu-py3.8-gcc11-inductor / test (inductor_torchbench_cpu_smoketest_perf, 1, 1, linux.24xl.spr-metal, unstable), inductor / cuda12.1-py3.10-gcc9-sm86 / test (dynamic_inductor_timm, 1, 2, linux.g5.4xlarge.nvidia.gpu), trunk / win-vs2019-cpu-py3 / build. Learn more about merging in the wiki. Questions? Feedback? Please reach out to the PyTorch DevX Team.
Merge failed. Reason: 1 job has failed, first few of them are: trunk / linux-focal-cuda12.4-py3.10-gcc9-sm86 / test (default, 3, 5, linux.g5.4xlarge.nvidia.gpu). Details for Dev Infra team: raised by workflow job.
@pytorchmergebot merge -i
Merge started: Your change will be merged while ignoring the following 5 checks: inductor-periodic / cuda12.4-py3.10-gcc9-sm86 / test (dynamic_inductor_torchbench, 2, 2, linux.g5.4xlarge.nvidia.gpu), inductor-periodic / cuda12.1-py3.10-gcc9-sm86-periodic-dynamo-benchmarks / test (dynamic_aot_eager_torchbench, 2, 2, linux.g5.4xlarge.nvidia.gpu), inductor / linux-jammy-cpu-py3.8-gcc11-inductor / test (inductor_torchbench_cpu_smoketest_perf, 1, 1, linux.24xl.spr-metal, unstable), inductor / cuda12.1-py3.10-gcc9-sm86 / test (dynamic_inductor_timm, 1, 2, linux.g5.4xlarge.nvidia.gpu), trunk / linux-focal-cuda12.4-py3.10-gcc9-sm86 / test (default, 3, 5, linux.g5.4xlarge.nvidia.gpu). Learn more about merging in the wiki. Questions? Feedback? Please reach out to the PyTorch DevX Team.
Fixes pytorch#121984
Pull Request resolved: pytorch#128097
Approved by: https://github.com/wconstab, https://github.com/wanchaol
Fixes #121984
cc @mrshenli @pritamdamania87 @zhaojuanmao @satgera @gqchen @aazzolini @osalpekar @jiayisuse @H-Huang @kwen2501 @awgu @penguinwu @fegin @XilunWu @wanchaol @fduwjj @tianyu-l @wconstab @yf225 @chauhang @d4l3k @msaroufim