[DeviceMesh] Update get_group and add get_all_groups by wz337 · Pull Request #128097 · pytorch/pytorch · GitHub

Conversation

wz337
Contributor

@wz337 wz337 commented Jun 6, 2024

@pytorch-bot

pytorch-bot bot commented Jun 6, 2024

🔗 Helpful Links

🧪 See artifacts and rendered test results at hud.pytorch.org/pr/128097

Note: Links to docs will display an error until the docs builds have been completed.

❗ 1 Active SEV

There is 1 currently active SEV. If your PR is affected, please view it below:

❌ 1 New Failure, 4 Unrelated Failures

As of commit a2640e5 with merge base 65aa16f:

NEW FAILURE - The following job has failed:

FLAKY - The following jobs failed but were likely due to flakiness present on trunk:

UNSTABLE - The following job failed but was likely due to flakiness present on trunk and has been marked as unstable:

This comment was automatically generated by Dr. CI and updates every 15 minutes.

@pytorch-bot pytorch-bot bot added ciflow/inductor oncall: distributed Add this issue/PR to distributed oncall triage queue release notes: distributed (fsdp) release notes category labels Jun 6, 2024
@wz337 wz337 changed the title [WIP]update get_group and add get_groups → [DeviceMesh] Update get_group and add get_groups Jun 6, 2024
@wz337 wz337 requested review from wanchaol and wconstab June 6, 2024 16:59
@wz337 wz337 marked this pull request as ready for review June 6, 2024 16:59
@wz337
Contributor Author

wz337 commented Jun 6, 2024

@wconstab @wanchaol As suggested, I updated get_group to return a single PG and added a new get_groups API. I am wondering how we should warn users about the change, since the API signature of get_group stays the same while the return type changes.

I believe most of the use cases I've seen call get_group with a mesh_dim passed in, so those won't be affected. The few use cases where the list is returned are actually in our own code base or tests. Should we throw a warning in get_group when the user does not pass in a mesh_dim and redirect them to get_groups if they are looking for the list of all the PGs?
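
For reference, a minimal usage sketch of the behavior being proposed here (the mesh shape and dim names are made up, and get_all_groups is the name the new API ends up with after the later rename in this PR):

```python
# Sketch only: assumes torch.distributed is already initialized across 8 ranks;
# the "dp"/"tp" dim names are placeholders, not something defined by this PR.
from torch.distributed.device_mesh import init_device_mesh

mesh = init_device_mesh("cuda", (2, 4), mesh_dim_names=("dp", "tp"))

tp_group = mesh.get_group("tp")      # a single ProcessGroup for this rank's "tp" dim
dp_group = mesh.get_group("dp")      # a single ProcessGroup for this rank's "dp" dim

all_groups = mesh.get_all_groups()   # the list of per-dim ProcessGroups

# On a >1D mesh, get_group() without mesh_dim now errors instead of returning a list.
```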

@facebook-github-bot
Contributor

@wz337 has imported this pull request. If you are a Meta employee, you can view this diff on Phabricator.

@wz337 wz337 added module: dtensor distributed tensor tag and removed release notes: distributed (fsdp) release notes category labels Jun 6, 2024
@pytorch-bot pytorch-bot bot added the release notes: distributed (fsdp) release notes category label Jun 6, 2024
@wz337 wz337 added release notes: distributed (dtensor) release notes category and removed release notes: distributed (fsdp) release notes category labels Jun 6, 2024
```python
return not_none(
    _find_pg_by_ranks_and_tag(*self._dim_group_infos[0][:2])

if self.mesh.ndim > 1 and mesh_dim is None:
    raise RuntimeError(
```
Contributor

do we want to raise the error here, or just issue a deprecation warning?

  • how many users do we think are already using this API and would hit this error?

cc @wanchaol

Contributor Author

@wz337 wz337 Jun 6, 2024

I have reviewed all the internal use cases I can find via fbgs mesh.get_group( (unless people named it something else), and it looks like they are all either:

  1. already calling get_group with mesh_dim specified, or
  2. calling get_group on a 1D child mesh without mesh_dim specified.

For these two cases we already return a single PG anyway, so they won't be affected.
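
Illustrated concretely (hypothetical mesh shape and dim names), the two unaffected patterns look like:

```python
# Hypothetical examples of the two call patterns listed above; both already
# received a single ProcessGroup before this change, so they keep working.
from torch.distributed.device_mesh import init_device_mesh

mesh_2d = init_device_mesh("cuda", (2, 4), mesh_dim_names=("dp", "tp"))

# 1. get_group with mesh_dim specified
dp_group = mesh_2d.get_group("dp")

# 2. get_group on a 1D child mesh without mesh_dim specified
tp_mesh = mesh_2d["tp"]          # slice out a 1D sub-mesh
tp_group = tp_mesh.get_group()   # only one dim, so no mesh_dim is needed
```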

Contributor

ok. it may be fine. if you want to derisk further against a revert, you could do a warning in this PR and stack a PR on top that changes the warning to an error. but i'll stamp to unblock.
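
For illustration only, a rough sketch of the warn-first option (this is not code from the PR; it is written as a standalone wrapper so it stays self-contained, and it assumes get_all_groups exists as added here):

```python
import warnings
from typing import Optional, Union

def get_group_with_deprecation(mesh, mesh_dim: Optional[Union[int, str]] = None):
    """Hypothetical warn-first behavior: keep the old list return for now,
    emit a FutureWarning, and flip to an error in a stacked follow-up PR."""
    if mesh.ndim > 1 and mesh_dim is None:
        warnings.warn(
            "Calling get_group() without mesh_dim on a >1D DeviceMesh is "
            "deprecated; pass mesh_dim, or use get_all_groups() for the full list.",
            FutureWarning,
            stacklevel=2,
        )
        return mesh.get_all_groups()  # preserve the old list-returning behavior for now
    return mesh.get_group(mesh_dim)
```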

Contributor Author

@wz337 wz337 Jun 6, 2024

> ok. it may be fine. if you want to derisk further against a revert, you could do a warning in this PR and stack a PR on top that changes the warning to an error. but i'll stamp to unblock.

I think it should be fine. I am importing this PR as a diff to let internal tests run on it. Either way, even if we only issue a warning, anything that still relies on a list being returned at this point would end up with an error anyway.

Collaborator

yeah, I think either a warning or an error works. Although this is a corner case, it's technically BC breaking, so we would probably need to put a line in the release notes to explain the change.

"""
Returns a list of ProcessGroups corresponding to the mesh dimensions, or
returns a single ProcessGroup if mesh_dim is specified or the given mesh has
Returns a single ProcessGroup if mesh_dim is specified or the given mesh has
Contributor

nit: rephrase as something like

"Returns the single ProcessGroup specified by mesh_dim, or, if mesh_dim is unspecified and the DeviceMesh is 1-dimensional, returns the only ProcessGroup in the mesh."

```python
dim_groups = mesh.get_group()
assert isinstance(dim_groups, list)
return dim_groups[0]
return dim_groups
```
Contributor

nit: it's not groups anymore; maybe just say return mesh.get_group() instead?
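
Applied to the snippet above, the call site would reduce to something like (the surrounding helper is hypothetical):

```python
# On a 1D mesh, get_group() now returns a single ProcessGroup directly,
# so the intermediate list handling can be dropped.
def _group_from_1d_mesh(mesh):
    return mesh.get_group()
```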

@wz337 wz337 force-pushed the fix_get_group branch 2 times, most recently from 2bd9ffc to 7b0a7e9 on June 6, 2024 22:31
@wz337
Contributor Author

wz337 commented Jun 6, 2024

@pytorchmergebot rebase

@pytorchmergebot
Collaborator

@pytorchbot started a rebase job onto refs/remotes/origin/viable/strict. Check the current status here

@pytorchmergebot
Collaborator

Successfully rebased fix_get_group onto refs/remotes/origin/viable/strict, please pull locally before adding more changes (for example, via git checkout fix_get_group && git pull --rebase)

@wz337
Contributor Author

wz337 commented Jun 7, 2024

@pytorchmergebot rebase

@wz337 wz337 changed the title [DeviceMesh] Update get_group and add get_groups → [DeviceMesh] Update get_group and add get_all_groups Jun 7, 2024
@facebook-github-bot
Contributor

@wz337 has imported this pull request. If you are a Meta employee, you can view this diff on Phabricator.

Collaborator

@wanchaol wanchaol left a comment

lgtm! one nit inlined


@wanchaol wanchaol added the topic: bc breaking topic category label Jun 7, 2024
@facebook-github-bot
Contributor

@wz337 has imported this pull request. If you are a Meta employee, you can view this diff on Phabricator.

@wz337
Contributor Author

wz337 commented Jun 8, 2024

@pytorchmergebot merge

@pytorch-bot pytorch-bot bot added the ciflow/trunk Trigger trunk jobs on your pull request label Jun 8, 2024
@pytorchmergebot
Collaborator

Merge started

Your change will be merged once all checks pass (ETA 0-4 Hours).

Learn more about merging in the wiki.

Questions? Feedback? Please reach out to the PyTorch DevX Team

Advanced Debugging: Check the merge workflow status here

@pytorchmergebot
Collaborator

Merge failed

Reason: 1 jobs have failed, first few of them are: trunk / win-vs2019-cpu-py3 / build

Details for Dev Infra team Raised by workflow job

@wz337
Contributor Author

wz337 commented Jun 8, 2024

@pytorchmergebot merge -i "unreleated trunk errror"

@pytorch-bot

pytorch-bot bot commented Jun 8, 2024

❌ 🤖 pytorchbot command failed:

@pytorchbot: error: unrecognized arguments: unreleated trunk errror

usage: @pytorchbot [-h] {merge,revert,rebase,label,drci,cherry-pick,close} ...

Try @pytorchbot --help for more info.

@wz337
Contributor Author

wz337 commented Jun 8, 2024

@pytorchmergebot merge -i

@pytorchmergebot
Collaborator

Merge failed

Reason: 1 jobs have failed, first few of them are: trunk / linux-focal-cuda12.4-py3.10-gcc9-sm86 / test (default, 3, 5, linux.g5.4xlarge.nvidia.gpu)

Details for Dev Infra team Raised by workflow job

@wz337
Contributor Author

wz337 commented Jun 8, 2024

@pytorchmergebot merge -i


Labels

  • ciflow/inductor
  • ciflow/trunk (Trigger trunk jobs on your pull request)
  • Merged
  • module: dtensor (distributed tensor tag)
  • oncall: distributed (Add this issue/PR to distributed oncall triage queue)
  • release notes: distributed (dtensor) (release notes category)
  • topic: bc breaking (topic category)

Projects

None yet

Development

Successfully merging this pull request may close these issues.

[DeviceMesh] get_group() docs and behavior inconsistent for mesh_dim=None

5 participants