KEMBAR78
[C10D] Support group_dst/group_src in c10d send/recv by wconstab · Pull Request #140460 · pytorch/pytorch · GitHub
Skip to content

Conversation

@wconstab
Copy link
Contributor

@wconstab wconstab commented Nov 12, 2024

Stack from ghstack (oldest at bottom):

Partly addressing RFC 0042 (pytorch/rfcs#71)

It's annoying that 'dst' (for send) ust be a global rank even when a
group is passed in. But we can't easily change 'dst' without breaking
existing cases.

Furthermore, requiring use of 'global' dst breaks the less common usage
pattern of creating a new ProcessGroup object that is not connected to
the 'default group' and thus has no logical 'global' ranks.

cc @H-Huang @awgu @kwen2501 @wanchaol @fegin @fduwjj @wz337 @d4l3k @c-p-i-o

Addressing RFC 0042 (pytorch/rfcs#71)

It's annoying that 'dst' (for send) ust be a global rank even when a
group is passed in.  But we can't easily change 'dst' without breaking
existing cases.

[ghstack-poisoned]
@pytorch-bot
Copy link

pytorch-bot bot commented Nov 12, 2024

🔗 Helpful Links

🧪 See artifacts and rendered test results at hud.pytorch.org/pr/140460

Note: Links to docs will display an error until the docs builds have been completed.

❗ 1 Active SEVs

There are 1 currently active SEVs. If your PR is affected, please view them below:

✅ No Failures

As of commit f558d56 with merge base de34f58 (image):
💚 Looks good so far! There are no failures yet. 💚

This comment was automatically generated by Dr. CI and updates every 15 minutes.

@pytorch-bot pytorch-bot bot added oncall: distributed Add this issue/PR to distributed oncall triage queue release notes: distributed (c10d) release notes category labels Nov 12, 2024
wconstab added a commit that referenced this pull request Nov 12, 2024
Addressing RFC 0042 (pytorch/rfcs#71)

It's annoying that 'dst' (for send) ust be a global rank even when a
group is passed in.  But we can't easily change 'dst' without breaking
existing cases.

ghstack-source-id: 80af56e
Pull Request resolved: #140460
# TODO if we want to encourage migration to 'group_dst' arg and deprecate dst arg, should we support using 'group_dst'
# even when 'group' is None (for default group)?
assert group is not None, "Must specify group if using group_dst"
dst = get_global_rank(group, group_dst)
Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

oops, this is dumb. i just realized that i can't implement group_dst as a function of global rank, since part of the reason for this PR is that 'group' might not be associated with the default-group at all, so get_global_rank function may not work. We will need to invert the logic below so that 'group_dst' is the thing that we directly pass into the group api, and we convert dst into group_dst at the top if dst is passed in. this will make the change a bit more invasive but still possible.

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

There is a global to group map mapping: _world.pg_group_ranks

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

yea, i realize that but the point is we want to not rely on any global stuff when using group_ argument. This is to enable use cases where someone creates a ProcessGroup object directly and it's not registered with us.

Addressing RFC 0042 (pytorch/rfcs#71)

It's annoying that 'dst' (for send) ust be a global rank even when a
group is passed in.  But we can't easily change 'dst' without breaking
existing cases.

cc H-Huang awgu kwen2501 wanchaol fegin fduwjj wz337 d4l3k c-p-i-o

[ghstack-poisoned]
Addressing RFC 0042 (pytorch/rfcs#71)

It's annoying that 'dst' (for send) ust be a global rank even when a
group is passed in.  But we can't easily change 'dst' without breaking
existing cases.

cc H-Huang awgu kwen2501 wanchaol fegin fduwjj wz337 d4l3k c-p-i-o

[ghstack-poisoned]
Addressing RFC 0042 (pytorch/rfcs#71)

It's annoying that 'dst' (for send) ust be a global rank even when a
group is passed in.  But we can't easily change 'dst' without breaking
existing cases.

Furthermore, requiring use of 'global' dst breaks the less common usage
pattern of creating a new ProcessGroup object that is not connected to
the 'default group' and thus has no logical 'global' ranks.

[ghstack-poisoned]
wconstab added a commit that referenced this pull request Nov 14, 2024
Addressing RFC 0042 (pytorch/rfcs#71)

It's annoying that 'dst' (for send) ust be a global rank even when a
group is passed in.  But we can't easily change 'dst' without breaking
existing cases.

Furthermore, requiring use of 'global' dst breaks the less common usage
pattern of creating a new ProcessGroup object that is not connected to
the 'default group' and thus has no logical 'global' ranks.

ghstack-source-id: 72264a2
Pull Request resolved: #140460
Copy link
Member

@d4l3k d4l3k left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

code looks good to me

Can we add some unit tests on this new logic?

group_src_rank = get_group_rank(pg, src)
pg.recv([tensor], group_src_rank, tag).wait()
return src
group_src = _canonicalize_group_rank(group, src, group_src)
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Can we use _check_not_self_rank(group, group_src, "source") here?

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

hmm, no, we apparently can't. It causes this test to fail
python test/distributed/test_distributed_spawn.py TestDistBackendWithSpawn.test_batch_isend_irecv_nccl

The test looks kinda dumb, i'm not sure if it really even intends to send/recv to itself, but that's what it does. And I think (guess?) that nccl supports this so we probably shouldn't start asserting against it now...

@fduwjj
Copy link
Contributor

fduwjj commented Nov 14, 2024

lg! maybe n00b question, what could be a use case for this?

Copy link
Contributor

@kwen2501 kwen2501 left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

LGTM. Thanks a lot.
Do you mind extending the change to isend, irecv, broadcast and reduce too?
Can be in a next PR.
Thanks!

Comment on lines 1129 to 1132
if group_rank is not None:
assert global_rank is None, "Can't specify both group_rank and global_rank"
else:
assert global_rank is not None, "Must specify global_rank or group_rank"
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

nit: raise ValueError instead? assert can be optimized away with -O flag.

@wconstab
Copy link
Contributor Author

lg! maybe n00b question, what could be a use case for this?

Well, first I wanted this when developing pipelining. We often want to send to our next pipeline rank using the pp group but we are forced to convert our pp rank to global rank which is inconvenient.

Second the RFC0042 issue came up and using local ranks is important if you don't want to assume all pgs are part of the same world group.

@kwen2501 yea planning to add other ops in additional PRs.

Addressing RFC 0042 (pytorch/rfcs#71)

It's annoying that 'dst' (for send) ust be a global rank even when a
group is passed in.  But we can't easily change 'dst' without breaking
existing cases.

Furthermore, requiring use of 'global' dst breaks the less common usage
pattern of creating a new ProcessGroup object that is not connected to
the 'default group' and thus has no logical 'global' ranks.

[ghstack-poisoned]
@wconstab wconstab changed the title [C10D] Support group_dst/group_src in c10d collectives [C10D] Support group_dst/group_src in c10d send/recv Nov 15, 2024
Partly addressing RFC 0042 (pytorch/rfcs#71)

It's annoying that 'dst' (for send) ust be a global rank even when a
group is passed in.  But we can't easily change 'dst' without breaking
existing cases.

Furthermore, requiring use of 'global' dst breaks the less common usage
pattern of creating a new ProcessGroup object that is not connected to
the 'default group' and thus has no logical 'global' ranks.

[ghstack-poisoned]
Partly addressing RFC 0042 (pytorch/rfcs#71)

It's annoying that 'dst' (for send) ust be a global rank even when a
group is passed in.  But we can't easily change 'dst' without breaking
existing cases.

Furthermore, requiring use of 'global' dst breaks the less common usage
pattern of creating a new ProcessGroup object that is not connected to
the 'default group' and thus has no logical 'global' ranks.

[ghstack-poisoned]
@kwen2501
Copy link
Contributor

Sorry I have more ops to suggggggest:
broadcast_object_list, scatter_object_list, send_object_list, recv_object_list.

Just a heads-up: in the implementation of these object coll functions, we should use the group_src arg for the inner calls. Otherwise there would be a double conversion :)

@wconstab
Copy link
Contributor Author

@kwen2501
Check later PRs
But yea maybe I have some double conversions

wconstab added a commit that referenced this pull request Nov 20, 2024
Changes semantic of __repr__ of P2POp: s, d are now group ranks instead
of global ranks. I think this is OK since I also updated the field names
to make this obvious.

Also add mypy annotations

Partially addresses RFC 0042 (pytorch/rfcs#71)
See more details/motivation in #140460

cc H-Huang awgu kwen2501 wanchaol fegin fduwjj wz337 d4l3k c-p-i-o

[ghstack-poisoned]
wconstab added a commit that referenced this pull request Nov 20, 2024
Changes semantic of __repr__ of P2POp: s, d are now group ranks instead
of global ranks. I think this is OK since I also updated the field names
to make this obvious.

Also add mypy annotations

Partially addresses RFC 0042 (pytorch/rfcs#71)
See more details/motivation in #140460

ghstack-source-id: 6991c41
Pull Request resolved: #141054
wconstab added a commit that referenced this pull request Nov 21, 2024
Changes semantic of __repr__ of P2POp: s, d are now group ranks instead
of global ranks. I think this is OK since I also updated the field names
to make this obvious.

Also add mypy annotations

Partially addresses RFC 0042 (pytorch/rfcs#71)
See more details/motivation in #140460

cc H-Huang awgu kwen2501 wanchaol fegin fduwjj wz337 d4l3k c-p-i-o

[ghstack-poisoned]
wconstab added a commit that referenced this pull request Nov 21, 2024
Changes semantic of __repr__ of P2POp: s, d are now group ranks instead
of global ranks. I think this is OK since I also updated the field names
to make this obvious.

Also add mypy annotations

Partially addresses RFC 0042 (pytorch/rfcs#71)
See more details/motivation in #140460

ghstack-source-id: 6f61786
Pull Request resolved: #141054
pytorchmergebot pushed a commit that referenced this pull request Nov 21, 2024
Changes semantic of __repr__ of P2POp: s, d are now group ranks instead
of global ranks. I think this is OK since I also updated the field names
to make this obvious.

Also add mypy annotations

Partially addresses RFC 0042 (pytorch/rfcs#71)
See more details/motivation in #140460

Pull Request resolved: #141054
Approved by: https://github.com/kwen2501
pobin6 pushed a commit to pobin6/pytorch that referenced this pull request Dec 5, 2024
Partly addressing RFC 0042 (pytorch/rfcs#71)

It's annoying that 'dst' (for send) ust be a global rank even when a
group is passed in.  But we can't easily change 'dst' without breaking
existing cases.

Furthermore, requiring use of 'global' dst breaks the less common usage
pattern of creating a new ProcessGroup object that is not connected to
the 'default group' and thus has no logical 'global' ranks.
Pull Request resolved: pytorch#140460
Approved by: https://github.com/d4l3k, https://github.com/kwen2501, https://github.com/fduwjj
pobin6 pushed a commit to pobin6/pytorch that referenced this pull request Dec 5, 2024
Avoid copypaste of send/isend and recv/irecv impl.

This does change the warning issued from send to include the identifier
"isend" instead of "send", but I think thats not a big deal.

Pull Request resolved: pytorch#140815
Approved by: https://github.com/fegin
ghstack dependencies: pytorch#140460
pobin6 pushed a commit to pobin6/pytorch that referenced this pull request Dec 5, 2024
…40820)

Faced with an annoying string of warnings like this when running tests,
<img width="1644" alt="Screenshot 2024-11-15 at 11 23 21 AM" src="https://github.com/user-attachments/assets/91ff4e1d-3c29-4510-9a61-46e7df68a212">

My choices seem to be (1) call destroy_process_group() at the end of
each test fn, (2) do this in some wrapper, (3) do it in the base test
class.

Since tests in MultiProcessTestCase are responsible for calling
init_process_group themselves, they should also be responsible for
calling destroy (or at least method (3) would be asymmetric and may
result in double-destroy).

But it doesn't feel worth it to go add a destroy call manually to each
test, and try/except for a possible second destroy call seems like a
happy middle ground.

Note: tests that want to ensure that destroy runs cleanly can and should
still call destroy _inside_ the test, and this change does not affect
that.

Pull Request resolved: pytorch#140820
Approved by: https://github.com/fegin
ghstack dependencies: pytorch#140460, pytorch#140815
pobin6 pushed a commit to pobin6/pytorch that referenced this pull request Dec 5, 2024
)

Also add missing mypy typing and a few asserts to make mypy happy

Partially addresses RFC 0042 (pytorch/rfcs#71)
See more details/motivation in pytorch#140460

Note: object collective version canonicalizes to global instead of group
rank, simply becuase this left more of the original code intact and
required less conversions overall.

Pull Request resolved: pytorch#140827
Approved by: https://github.com/kwen2501
pobin6 pushed a commit to pobin6/pytorch that referenced this pull request Dec 5, 2024
Also add mypy annotations

Partially addresses RFC 0042 (pytorch/rfcs#71)
See more details/motivation in pytorch#140460
Pull Request resolved: pytorch#140843
Approved by: https://github.com/kwen2501
pobin6 pushed a commit to pobin6/pytorch that referenced this pull request Dec 5, 2024
…orch#140847)

Also add mypy annotations

Partially addresses RFC 0042 (pytorch/rfcs#71)
See more details/motivation in pytorch#140460

Pull Request resolved: pytorch#140847
Approved by: https://github.com/H-Huang
ghstack dependencies: pytorch#140843
pobin6 pushed a commit to pobin6/pytorch that referenced this pull request Dec 5, 2024
…1054)

Changes semantic of __repr__ of P2POp: s, d are now group ranks instead
of global ranks. I think this is OK since I also updated the field names
to make this obvious.

Also add mypy annotations

Partially addresses RFC 0042 (pytorch/rfcs#71)
See more details/motivation in pytorch#140460

Pull Request resolved: pytorch#141054
Approved by: https://github.com/kwen2501
@github-actions github-actions bot deleted the gh/wconstab/361/head branch December 19, 2024 02:10
fightingand pushed a commit to fightingand/pytorch that referenced this pull request Dec 20, 2024
Addressing RFC 0042 (pytorch/rfcs#71)

It's annoying that 'dst' (for send) ust be a global rank even when a
group is passed in.  But we can't easily change 'dst' without breaking
existing cases.

Furthermore, requiring use of 'global' dst breaks the less common usage
pattern of creating a new ProcessGroup object that is not connected to
the 'default group' and thus has no logical 'global' ranks.

ghstack-source-id: 33ea136
Pull Request resolved: pytorch/pytorch#140460
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment

Labels

Merged oncall: distributed Add this issue/PR to distributed oncall triage queue release notes: distributed (c10d) release notes category

Projects

None yet

Development

Successfully merging this pull request may close these issues.

7 participants