[PGNCCL] Fix behavior of destroy_process_group by kwen2501 · Pull Request #141510 · pytorch/pytorch · GitHub

Conversation

@kwen2501
Contributor

@kwen2501 kwen2501 commented Nov 25, 2024

Stack from ghstack (oldest at bottom):

Today destroy_process_group() is implemented via ncclCommAbort.
When a user calls it from the CPU side, the risk is that a healthy NCCL kernel gets preempted, which can cause data corruption.

Instead of aborting kernels, we should flush collectives in destroy_process_group, i.e. let them complete normally, before we tear down resources.

This PR implements such "flushing" behavior using ncclCommFinalize, then reclaims resources via ncclCommDestroy.

Expected behaviors:
For a bad program, a hang is expected at destroy_process_group(). If the PG uses non-blocking communicators, such a hang is recoverable, because we attach a timeout to the flush behavior.
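A minimal sketch of the intended user-visible flow (illustrative only, assuming a standard NCCL setup launched with torchrun; not code from this PR):

import torch
import torch.distributed as dist

dist.init_process_group("nccl")
torch.cuda.set_device(dist.get_rank() % torch.cuda.device_count())

t = torch.ones(1 << 20, device="cuda")
dist.all_reduce(t)  # the NCCL kernel may still be running on the GPU here

# With this PR, destroy_process_group() finalizes (flushes) the communicator,
# letting the all_reduce complete normally, and only then destroys it.
# Previously the abort path could preempt a healthy kernel. A mismatched
# program would instead hang here (with a timeout when non-blocking
# communicators are used).
dist.destroy_process_group()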

cc @H-Huang @awgu @wanchaol @fegin @fduwjj @wz337 @wconstab @d4l3k @c-p-i-o

@pytorch-bot pytorch-bot bot added oncall: distributed Add this issue/PR to distributed oncall triage queue release notes: distributed (c10d) release notes category labels Nov 25, 2024
@pytorch-bot

pytorch-bot bot commented Nov 25, 2024

🔗 Helpful Links

🧪 See artifacts and rendered test results at hud.pytorch.org/pr/141510

Note: Links to docs will display an error until the docs builds have been completed.

✅ No Failures

As of commit ecdf475 with merge base 61dc5e9:
💚 Looks good so far! There are no failures yet. 💚

This comment was automatically generated by Dr. CI and updates every 15 minutes.

@kwen2501 kwen2501 added keep-going Don't stop on first failure, keep running tests until the end ciflow/trunk Trigger trunk jobs on your pull request labels Nov 27, 2024
@kwen2501 kwen2501 changed the title [PGNCCL] Implement destroy behavior in shutdown() [PGNCCL] Fix behavior of destroy_process_group Dec 4, 2024
// Note: we have rewritten `shutdown` to represent the destroy behavior.
// Here we route to `abort()` explicitly to maintain the old behavior, until
// we fix everything.
abort();
Contributor

It's good that this codepath tries to preserve legacy behavior, since otherwise we would risk introducing hangs on shutdown where they didn't exist before.

On the other hand, is shutdown() considered a public API surface itself?

Should we consider one of the following (both options are sketched below)?
(1) making 'wait' a flag on the existing shutdown API (e.g. shutdown(wait_on_ops=False)) to make sure we always preserve BC
(2) just leaving shutdown alone but adding a new method for 'clean shutdown', marking shutdown as deprecated
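For illustration only, a purely hypothetical sketch of the two options above; neither signature exists in PyTorch, and the names are invented here just to show the shape of the proposal:

# Hypothetical sketch only -- not a real PyTorch API.
class _ProcessGroupNCCLProposal:
    def shutdown(self, wait_on_ops: bool = True) -> None:
        """Option (1): keep shutdown() but add a flag. wait_on_ops=True gives the
        new flush-then-destroy behavior; wait_on_ops=False keeps the legacy
        abort-style path for backward compatibility."""
        ...

    def clean_shutdown(self) -> None:
        """Option (2): leave shutdown() as-is (eventually deprecating it) and add
        a separate entry point that waits for outstanding NCCL work before
        destroying communicators."""
        ...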

Contributor Author

@kwen2501 kwen2501 Dec 4, 2024

shutdown is not a public API per my understanding.
It was pybind'ed as _shutdown, and then used in dist.destroy_process_group.

.def(
    "_shutdown",
    [](const c10::intrusive_ptr<::c10d::ProcessGroupNCCL>& self) {
      return self->shutdown();
    },
    py::call_guard<py::gil_scoped_release>())

(Thus the name shutdown doesn't really matter -- the fact it gets used by destroy_process_group means that it should carry the flush + destroy behavior -- no destroy is possible without waiting for ops to finish.)

For behavior like wait_on_ops=False, users should be directed to use abort_process_group instead, I think.
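As a rough illustration of the call path described here (names taken from the pybind snippet above; these are private bindings that may change across releases, and this is not the actual body of destroy_process_group):

import torch
import torch.distributed as dist

def _teardown_nccl_backend(pg: dist.ProcessGroup) -> None:
    # _get_backend is a private accessor returning the NCCL backend object
    # (ProcessGroupNCCL) that backs the process group.
    nccl_backend = pg._get_backend(torch.device("cuda"))
    # _shutdown is the pybind'ed wrapper around ProcessGroupNCCL::shutdown(),
    # which after this PR finalizes communicators -- waiting for enqueued
    # collectives, with a timeout for non-blocking communicators -- and then
    # destroys them. The "don't wait" escape hatch is the abort path
    # (abort_process_group) rather than a flag on shutdown.
    nccl_backend._shutdown()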

// Difference between `abort()` and `shutdown()`:
// 1. `abort()` will signal communicators to terminate all NCCL kernels
// immediately.
// 2. `shutdown()` will wait for all NCCL kernels to finish before destroying
Contributor

Is shutdown blocking?

Contributor Author

Yes, it is blocking on purpose.
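A small timing sketch of what "blocking" means here (illustrative only; assumes a working NCCL setup launched with torchrun):

import os, time
import torch
import torch.distributed as dist

# torchrun sets RANK/WORLD_SIZE/LOCAL_RANK for us.
torch.cuda.set_device(int(os.environ["LOCAL_RANK"]))
dist.init_process_group("nccl")

big = torch.randn(1 << 26, device="cuda")
dist.all_reduce(big)          # returns to the host immediately; kernel may still be running

t0 = time.perf_counter()
dist.destroy_process_group()  # blocks here until the outstanding all_reduce has completed
print(f"destroy_process_group took {time.perf_counter() - t0:.3f}s (includes the flush)")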

Contributor

@fduwjj fduwjj left a comment

Do you want to add a unit test for this?

@kwen2501
Contributor Author

kwen2501 commented Dec 4, 2024

Do you want to add a unit test for this?

test_c10d_nccl.py has ~50 calls to destroy_process_group. I think we can rely on them to test this change.
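For reference, a sketch of what a targeted check could look like (illustrative only, not part of this PR; assumes two CUDA devices and an NCCL build):

import os
import torch
import torch.distributed as dist
import torch.multiprocessing as mp

def _worker(rank: int, world_size: int) -> None:
    os.environ["MASTER_ADDR"] = "127.0.0.1"
    os.environ["MASTER_PORT"] = "29501"
    torch.cuda.set_device(rank)
    dist.init_process_group("nccl", rank=rank, world_size=world_size)

    x = torch.ones(1 << 24, device="cuda")
    dist.all_reduce(x)            # enqueue a collective; do not synchronize
    dist.destroy_process_group()  # should flush (finalize) rather than abort

    # If the collective was flushed instead of aborted, the result is intact.
    expected = torch.full((1 << 24,), float(world_size))
    assert torch.equal(x.cpu(), expected), "all_reduce result was corrupted"

if __name__ == "__main__":
    world_size = 2
    mp.spawn(_worker, args=(world_size,), nprocs=world_size, join=True)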

@kwen2501
Contributor Author

kwen2501 commented Dec 4, 2024

@pytorchbot merge -f "CI was green previously; new change just fixes typo"

@pytorchmergebot
Collaborator

Merge started

Your change will be merged immediately since you used the force (-f) flag, bypassing any CI checks (ETA: 1-5 minutes). Please use -f as last resort and instead consider -i/--ignore-current to continue the merge ignoring current failures. This will allow currently pending tests to finish and report signal before the merge.

Learn more about merging in the wiki.

Questions? Feedback? Please reach out to the PyTorch DevX Team

Advanced Debugging
Check the merge workflow status here.

pobin6 pushed a commit to pobin6/pytorch that referenced this pull request Dec 5, 2024
Pull Request resolved: pytorch#141510
Approved by: https://github.com/wconstab
AmdSampsa pushed a commit to AmdSampsa/pytorch that referenced this pull request Dec 9, 2024
Pull Request resolved: pytorch#141510
Approved by: https://github.com/wconstab
pytorchmergebot pushed a commit that referenced this pull request Dec 9, 2024

Making CUDA or NCCL calls during object destruction can be dangerous because the CUDA context may have exited before the destructor runs, in which case the CUDA calls would see a "CUDA driver shutting down" error.

This PR does take a destroy call away from the NCCLComm dtor, and doesn't add a new one. If users are calling destroy_process_group or abort_process_group as recommended, then we are destroying for them; otherwise we are OK with letting them possibly leak resources (and get a warning).

Pull Request resolved: #141511
Approved by: https://github.com/eqy, https://github.com/wconstab
ghstack dependencies: #141510
pytorchmergebot pushed a commit that referenced this pull request Dec 10, 2024
And removed some unnecessary conditions for calling `thread.join()` -- `thread.joinable()` should have covered it.

Pull Request resolved: #142297
Approved by: https://github.com/wconstab
ghstack dependencies: #141510, #141511
@github-actions github-actions bot deleted the gh/kwen2501/104/head branch January 4, 2025 02:06