[c10d] allow sub group to be eagerly inited even if default one is not by shuqiangzhang · Pull Request #138665 · pytorch/pytorch · GitHub

Conversation

@shuqiangzhang
Contributor

@shuqiangzhang shuqiangzhang commented Oct 23, 2024

Stack from ghstack (oldest at bottom):

Summary:
Currently, eager mode is applied either to all PGs or none of them.
There are cases where we don't want to initialize the comms for the
default PG but still want to initialize the comms for a sub PG. Now,
with a device_id passed to new_group, we can support this case.
Test Plan:
Newly added unit test.

Tags:

Resolves #137018
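The behavior this PR enables can be modeled with a small, self-contained sketch. Note this is a toy stand-in, not the real c10d implementation: `ToyProcessGroup`, `init_process_group`, and `new_group` here are illustrative names that only mimic the eager-vs-lazy decision (comms are initialized eagerly exactly when a `device_id` is bound).

```python
# Toy model of the eager-init behavior this PR enables: the default PG can
# stay lazily initialized while a subgroup created with a device_id is
# eagerly initialized. All names are illustrative, not actual c10d code.

class ToyProcessGroup:
    def __init__(self, ranks, device_id=None):
        self.ranks = ranks
        self.bound_device_id = device_id
        self.comm_initialized = False
        if device_id is not None:
            self._eager_init()  # eager path: init comms right away

    def _eager_init(self):
        # Stand-in for the real communicator creation (ncclCommInit* /
        # ncclCommSplit on the NCCL backend).
        self.comm_initialized = True


def init_process_group(world_ranks, device_id=None):
    return ToyProcessGroup(world_ranks, device_id)


def new_group(ranks, device_id=None):
    return ToyProcessGroup(ranks, device_id)


default_pg = init_process_group([0, 1, 2, 3])   # no device_id: stays lazy
sub_pg = new_group([0, 1], device_id="cuda:0")  # device_id: eagerly inited

print(default_pg.comm_initialized)  # False
print(sub_pg.comm_initialized)      # True
```

In the real API the call shape is analogous: passing `device_id` to `new_group` requests eager initialization for just that subgroup, independent of whether `init_process_group` was called with one.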

cc @H-Huang @awgu @kwen2501 @wanchaol @fegin @fduwjj @wz337 @wconstab @d4l3k @c-p-i-o

@pytorch-bot

pytorch-bot bot commented Oct 23, 2024

🔗 Helpful Links

🧪 See artifacts and rendered test results at hud.pytorch.org/pr/138665

Note: Links to docs will display an error until the docs builds have been completed.

✅ No Failures

As of commit 2b515fd with merge base 8aedc64:
💚 Looks good so far! There are no failures yet. 💚

This comment was automatically generated by Dr. CI and updates every 15 minutes.

@pytorch-bot pytorch-bot bot added the release notes: distributed (c10d) release notes category label Oct 23, 2024
shuqiangzhang added a commit that referenced this pull request Oct 23, 2024
ghstack-source-id: b4e77ef
Pull Request resolved: #138665
@shuqiangzhang shuqiangzhang requested review from H-Huang, c-p-i-o, fduwjj, kwen2501 and wconstab and removed request for wconstab October 23, 2024 01:12
@kwen2501
Contributor

kwen2501 commented Oct 23, 2024

I agree with the addition of device_id to the new_group API.
(I wonder if we should start calling the kwarg device instead of device_id -- it is actually a torch.device, rather than an id / index. Although, device_id makes the API aligned with init_process_group, so I am undecided.)

@kwen2501
Contributor

Does this PR need to depend on #138518?

@shuqiangzhang
Contributor Author

Does this PR need to depend on #138518?

Yes. Otherwise, eager mode is coupled with the split logic: if a subgroup is eagerly inited, it has to use split, which requires the default PG to also be eagerly inited, which is against the intention of this PR.

@shuqiangzhang
Contributor Author

I agree with the addition of device_id to the new_group API. (I wonder if we should start calling the kwarg device instead of device_id -- it is actually a torch.device, rather than an id / index. Although, device_id makes the API aligned with init_process_group, so I am undecided.)

For the purpose of this PR, let's keep it consistent with the naming of the init_process_group API? We could rename all of them in other PRs.

@kwen2501
Contributor

kwen2501 commented Oct 23, 2024

if subgroup is eager inited, it has to use split

A subgroup can be eagerly inited the same way as a default group is eagerly inited, without using split. I wonder if we could use that to implement this PR here?

That is:
(1) if the parent group does not have a bound device id,
call ncclCommInitConfig eagerly for the subgroup;
(2) if the parent group has a bound device id,
call ncclCommSplit eagerly to create the subgroup.
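The two cases above can be sketched as a tiny selector function. This is a hypothetical helper for illustration only; `pick_eager_init_path` is not a real c10d function, and the strings merely name the NCCL call each branch would take.

```python
def pick_eager_init_path(parent_bound_device_id):
    """Choose how a subgroup's communicator would be eagerly created.

    Mirrors the two cases in the comment above (illustrative only):
    - parent PG has no bound device id -> init the subgroup from scratch
      (the ncclCommInitConfig path);
    - parent PG has a bound device id  -> split off the parent's
      communicator (the ncclCommSplit path).
    """
    if parent_bound_device_id is None:
        return "ncclCommInitConfig"
    return "ncclCommSplit"


print(pick_eager_init_path(None))      # ncclCommInitConfig
print(pick_eager_init_path("cuda:0"))  # ncclCommSplit
```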

@kwen2501
Contributor

kwen2501 commented Oct 23, 2024

We could re-name all of them in other PRs.

Once added, this is a public argument, so there will be deprecation consequences if we'd like to change it in the future. But again, I am okay with either name.

@kwen2501
Contributor

let's keep it consistent with the naming of init_process_group API

Okay let's do that.

def rank(self) -> int: ...
def size(self) -> int: ...
def eager_connect_single_device(self, device: torch.device | None) -> None: ...
def is_initialized(self) -> bool: ...
Contributor

What is this new method for? If it is for users, we can expose it; if it is just for testing, I think we should defer it.

Contributor Author

It was intended for both, actually. It's a safer way to check whether a PG is fully initialized and ready to split. Right now we only allow splitting from an eagerly inited PG, but users can actually split a new PG from a non-eagerly-inited PG if its is_initialized is true.

Comment on lines 4778 to 4782
device_id (torch.device, optional): a single, specific device
to "bind" this process to, allowing for backend-specific
optimizations. Only under NCCL: the communicator is immediately formed
(calling ``ncclCommInit*`` immediately rather than the normal lazy
call).
Contributor

This piece of documentation seems to come from init_process_group. I wonder if we could update it now that our view of device_id is clearer? For example:

a single, specific device to "bind" this process to. The `new_group` call will try to initialize a communication backend for the device if this field is given.

Comment on lines +4829 to +4830
if device_id is None:
device_id = default_pg.bound_device_id
Contributor

The current logic seems fine. Shall we also add a check here to make sure the given device_id is the same as default_pg.bound_device_id when both are not None?
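The fallback quoted above plus the suggested consistency check could be sketched as follows. This is a hypothetical standalone helper, not the actual new_group implementation; `resolve_device_id` and its error message are made up for illustration.

```python
def resolve_device_id(device_id, default_pg_bound_device_id):
    """Resolve the device_id for a new subgroup (illustrative sketch).

    - If no device_id is given, fall back to the default PG's bound
      device (the existing logic quoted above).
    - If both are given, reject a mismatch (the reviewer's suggestion).
    """
    if device_id is None:
        return default_pg_bound_device_id
    if (default_pg_bound_device_id is not None
            and device_id != default_pg_bound_device_id):
        raise ValueError(
            f"device_id {device_id!r} does not match the default PG's "
            f"bound device {default_pg_bound_device_id!r}"
        )
    return device_id


print(resolve_device_id(None, "cuda:0"))      # cuda:0
print(resolve_device_id("cuda:0", "cuda:0"))  # cuda:0
```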

Comment on lines 389 to 394
// whether the backend is fully initialized, e.g., for NCCL, if the NCCL comms
// are fully initialized and ready to use.
virtual bool isInitialized() {
return false;
}

Contributor

Could it be a bit unsafe to assume a default value here?
Maybe we should throw an unimplemented error here and force backends to implement it? But it will induce more work, including contacting 3rd-party backends to implement this.
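The two options under discussion, sketched in Python for brevity (the real interface is the C++ `Backend` class; these toy classes are illustrative, not actual PyTorch code):

```python
class BackendWithDefault:
    # Option 1 (what the PR does): the base class assumes "not fully
    # initialized" unless a backend overrides the method.
    def is_initialized(self):
        return False


class BackendMustImplement:
    # Option 2 (the alternative raised here): force every backend,
    # including third-party ones, to answer explicitly.
    def is_initialized(self):
        raise NotImplementedError("backend must implement is_initialized()")


class ToyNCCLBackend(BackendMustImplement):
    # A backend that tracks its own comm state and overrides the check.
    def __init__(self):
        self._comms_ready = True

    def is_initialized(self):
        return self._comms_ready


print(BackendWithDefault().is_initialized())  # False
print(ToyNCCLBackend().is_initialized())      # True
```

Option 1 is safer for backward compatibility (existing backends keep working) at the cost of possibly reporting "not initialized" for a backend that simply never overrode the method.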

Contributor Author

If this is a problem, we could move this interface to the NCCL backend only, similar to other APIs such as bound_device_id or abort.

shuqiangzhang added a commit that referenced this pull request Oct 24, 2024
ghstack-source-id: 914023b
Pull Request resolved: #138665
Contributor

@kwen2501 kwen2501 left a comment

LGTM overall. The comments are minor.


tensor = torch.full((1,), self.rank).cuda(device)
new_group = c10d.new_group([0, 1], device_id=device)
self.assertEqual(backend.comm_split_count(), 0)
Contributor

nit: I propose we stop using comm_split_count for testing.

Contributor Author

This one is widely used in split-related tests; any alternatives?

Contributor

No alternatives. Just proposing that we stop using it.

Contributor Author

@shuqiangzhang shuqiangzhang Oct 24, 2024

Then not using it is basically equivalent to removing valid Python tests and safety guards against future code changes. Unless we have an alternative, e.g., fully trusted C++ tests (which would still miss the e2e Python test), I don't think it is a good idea to remove all of its usages.

new_backend = new_group._get_backend(torch.device(device))
self.assertEqual(new_backend._is_initialized(), True)
dist.broadcast(tensor, 0, group=new_group)
self.assertEqual(new_backend.comm_split_count(), 0)
Contributor

Same here

dist.broadcast(tensor, 0, group=new_group)
self.assertEqual(new_backend.comm_split_count(), 0)
self.assertEqual(backend._is_initialized(), False)
torch.cuda.synchronize()
Contributor

nit: is this synchronize necessary?

Contributor Author

Yes, because things could be aborted before the collective is completed.

@shuqiangzhang
Contributor Author

@pytorchbot merge -f "no failures"

@pytorchmergebot
Collaborator

Merge started

Your change will be merged immediately since you used the force (-f) flag, bypassing any CI checks (ETA: 1-5 minutes). Please use -f as last resort and instead consider -i/--ignore-current to continue the merge ignoring current failures. This will allow currently pending tests to finish and report signal before the merge.

Learn more about merging in the wiki.

Questions? Feedback? Please reach out to the PyTorch DevX Team

Advanced Debugging
Check the merge workflow status here.

@github-actions github-actions bot deleted the gh/shuqiangzhang/54/head branch November 24, 2024 02:14

Labels

Merged; oncall: distributed (add this issue/PR to distributed oncall triage queue); release notes: distributed (c10d) (release notes category)
