[c10d] fix sequence numbers for coalesced operations by c-p-i-o · Pull Request #135132 · pytorch/pytorch · GitHub

Conversation

@c-p-i-o
Contributor

@c-p-i-o c-p-i-o commented Sep 4, 2024

Summary:
We were erroneously incrementing `seqCollective_` for p2p operations.
Fixes issue #134833.

Test Plan:
Unit tests.
TODO: add more unit tests

Reviewers:

Subscribers:

Tasks:

Tags:

Fixes #ISSUE_NUMBER

cc @XilunWu @H-Huang @awgu @kwen2501 @wanchaol @fegin @fduwjj @wz337 @wconstab @d4l3k

@pytorch-bot

pytorch-bot bot commented Sep 4, 2024

🔗 Helpful Links

🧪 See artifacts and rendered test results at hud.pytorch.org/pr/135132

Note: Links to docs will display an error until the docs builds have been completed.

❌ 4 New Failures

As of commit 3031918 with merge base 66db61f:

NEW FAILURES - The following jobs have failed:

  • linux-binary-libtorch-pre-cxx11 / libtorch-cpu-shared-with-deps-pre-cxx11-test / test (gh)
  • linux-binary-manywheel / manywheel-py3_9-cuda11_8-test / test (gh)
  • linux-binary-manywheel / manywheel-py3_9-cuda12_1-test / test (gh)
  • linux-binary-manywheel / manywheel-py3_9-cuda12_4-test / test (gh)

All four jobs fail with the same error:
    RuntimeError: recursive_directory_iterator in used pre-CXX11 binaries, see; ['std::filesystem::recursive_directory_iterator::recursion_pending() const', 'std::filesystem::recursive_directory_iterator::depth() const', 'std::filesystem::recursive_directory_iterator::options() const', 'std::filesystem::recursive_directory_iterator::operator*() const', 'std::filesystem::recursive_directory_iterator::disable_recursion_pending()', 'std::filesystem::recursive_directory_iterator::pop(std::error_code&)', 'std::filesystem::recursive_directory_iterator::pop()', 'std::filesystem::recursive_directory_iterator::pop() [clone .cold]', 'std::filesystem::recursive_directory_iterator::increment(std::error_code&)', 'std::filesystem::recursive_directory_iterator::increment(std::error_code&) [clone .cold]', 'std::filesystem::recursive_directory_iterator::operator=(std::filesystem::recursive_directory_iterator&&)', 'std::filesystem::recursive_directory_iterator::operator=(std::filesystem::recursive_directory_iterator const&)', 'std::filesystem::recursive_directory_iterator::recursive_directory_iterator(std::filesystem::path const&, std::filesystem::directory_options, std::error_code*)', 'std::filesystem::recursive_directory_iterator::recursive_directory_iterator(std::filesystem::path const&, std::filesystem::directory_options, std::error_code*)', 'std::filesystem::recursive_directory_iterator::recursive_directory_iterator(std::filesystem::path const&, std::filesystem::directory_options, std::error_code*) [clone .cold]', 'std::filesystem::recursive_directory_iterator::~recursive_directory_iterator()', 'std::filesystem::recursive_directory_iterator::~recursive_directory_iterator()', 'std::filesystem::recursive_directory_iterator::operator++()', 'std::filesystem::recursive_directory_iterator::operator++() [clone .cold]']

This comment was automatically generated by Dr. CI and updates every 15 minutes.

@pytorch-bot pytorch-bot bot added the oncall: distributed Add this issue/PR to distributed oncall triage queue label Sep 4, 2024
@c-p-i-o c-p-i-o self-assigned this Sep 4, 2024
@pytorch-bot pytorch-bot bot added the release notes: distributed (c10d) release notes category label Sep 4, 2024
@c-p-i-o c-p-i-o force-pushed the cpio/fix_seq_nums_for_coalescing branch 2 times, most recently from 57086fe to d0d413c Compare September 4, 2024 22:36
Contributor

Does the approach in this PR assume that we can put a ++ in every place where we start coalescing?

I think this fails because user code can also start a coalescing group (e.g. from the Python side), and that path will not bump the counter.

Contributor Author

I was assuming that, from the Python side, the coalescing group eventually calls the `ProcessGroupNCCL::collectiveCoalesced` function. In that function, we already do `seqCollective_++`.

Does the Python side not eventually call the function mentioned above?

Contributor Author

I should mention that the intent here is to fix two use cases:

  1. all_gather with different outputTensor sizes, which automatically turns on coalescing, and
  2. reduce_scatter with different inputTensor_ sizes.

Whenever Python calls a coalesced operation, the call (if I read the code correctly) makes it to the collectiveCoalesced function (see the sketch below for what these two call patterns look like).
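
Since the discussion hinges on these two call patterns, here is a minimal, hedged sketch of what they look like from the user side. The shapes and group setup are illustrative assumptions, not code from this PR; the point is that when per-rank sizes differ, ProcessGroupNCCL cannot issue one flat NCCL call and takes the coalesced path discussed here.

```python
# Illustrative sketch (not taken from this PR): the two uneven-size patterns
# that route through the coalesced path. Assumes init_process_group() has run
# and each rank owns one CUDA device.
import torch
import torch.distributed as dist

def uneven_collectives(rank: int, world_size: int) -> None:
    device = torch.device("cuda", rank)

    # Case 1: all_gather where every rank contributes a different-sized tensor,
    # so the output list holds tensors of different sizes.
    inp = torch.ones(rank + 1, device=device)  # rank r contributes r + 1 elements
    outs = [torch.empty(r + 1, device=device) for r in range(world_size)]
    dist.all_gather(outs, inp)

    # Case 2: reduce_scatter where the per-rank input chunks have different sizes,
    # so each rank receives an output of a different size.
    ins = [torch.ones(r + 1, device=device) for r in range(world_size)]
    out = torch.empty(rank + 1, device=device)
    dist.reduce_scatter(out, ins)
```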

Contributor Author

I changed the code to increment `seqCollective_` in the `ProcessGroupNCCL::collective` operation once per coalesced collective. This way, it should cover all API calls into this area.

Contributor

Does the Python side not eventually call the function mentioned above?

Unfortunately, there are two paths from Python. Some of the Python-side coalescing starts by calling 'start', then issues normal coalesced calls, then calls 'end':
https://github.com/pytorch/pytorch/blob/main/torch/distributed/distributed_c10d.py#L2274
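
For readers unfamiliar with that path, a rough sketch of the Python-side start/end coalescing pattern is below. `_coalescing_manager` is a private API; the exact signature and the `cm.wait()` behavior shown here are assumptions that may differ across PyTorch versions.

```python
# Rough sketch of the Python-side "start ... end" coalescing path.
# _coalescing_manager is a private API; the signature shown is an assumption.
import torch
import torch.distributed as dist
from torch.distributed.distributed_c10d import _coalescing_manager

def coalesced_allreduces(tensors, group=None):
    device = tensors[0].device
    # Entering the context starts a coalescing group on the process group;
    # exiting it ends the group, flushing the collectives issued inside as one
    # coalesced NCCL launch with a single work object.
    with _coalescing_manager(group=group, device=device, async_ops=True) as cm:
        for t in tensors:
            dist.all_reduce(t, group=group)
    cm.wait()  # wait on the single work object for the whole coalesced group
```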

increment `seqCollective_` in the `ProcessGroupNCCL::collective` operation once per coalesced collective

IIRC this is actually a change in semantics. We need to document what the semantics of the seq numbers are. One rationale is that we want one seq number per actual 'work' object, i.e. per actual GPU kernel. Since coalescing groups logical operations together into one actual GPU kernel, we also create only one 'work' object, which we enqueue into the watchdog for monitoring. It may be confusing to increment the seq number for each collective inside the coalescing group.

For P2P + flight recorder, I remember confronting this and adding a new 'op id' or something similar, which does increment on every op but is kept separate from the seq number. The FR could then use the actual seq number to match whole coalescing groups with each other.
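
As a concrete illustration of the coalesced-p2p case being discussed, `batch_isend_irecv` groups several point-to-point ops into a single coalescing group. The sketch below is an example under assumed conditions (a two-rank NCCL job, one CUDA device per rank), not code from this PR.

```python
# Sketch: coalesced p2p via batch_isend_irecv. The P2POps in the list are
# issued inside one coalescing group; assumes a 2-rank NCCL job.
import torch
import torch.distributed as dist

def exchange(rank: int) -> None:
    device = torch.device("cuda", rank)
    send_t = torch.full((4,), float(rank), device=device)
    recv_t = torch.empty(4, device=device)
    peer = 1 - rank  # the other rank in a 2-rank job

    ops = [
        dist.P2POp(dist.isend, send_t, peer),
        dist.P2POp(dist.irecv, recv_t, peer),
    ]
    for req in dist.batch_isend_irecv(ops):
        req.wait()
```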

@c-p-i-o c-p-i-o force-pushed the cpio/fix_seq_nums_for_coalescing branch 3 times, most recently from 3cf699b to 8663689 Compare September 5, 2024 03:32
@c-p-i-o c-p-i-o marked this pull request as ready for review September 5, 2024 03:59
@c-p-i-o c-p-i-o requested a review from fduwjj September 5, 2024 03:59
@c-p-i-o c-p-i-o force-pushed the cpio/fix_seq_nums_for_coalescing branch 2 times, most recently from 57fe706 to 18d13d6 Compare September 5, 2024 22:21
Contributor

@fduwjj fduwjj left a comment

Nice!

input_sizes = op_sizes[seq % ops_per_repeat]
profiling_name = "nccl:recv 0<-1" if self.rank == 0 else "nccl:send 1->0"
self.assertEqual(t["entries"][seq]["profiling_name"], profiling_name)
# we don't increment collective_seq_id for p2p ops.
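
For context on the snippet above: in the c10d tests, the trace object `t` typically comes from the NCCL flight recorder dump. The sketch below is hedged; the `_dump_nccl_trace` hook and the per-entry keys are assumptions based on those tests, not part of this PR's diff.

```python
# Hedged sketch: inspecting flight recorder entries the way the c10d tests do.
# _dump_nccl_trace is a private hook; its name and output layout are assumptions.
import pickle
import torch

def dump_entries():
    t = pickle.loads(torch._C._distributed_c10d._dump_nccl_trace())
    for entry in t["entries"]:
        # p2p ops advance their own p2p sequence id; collective_seq_id should
        # not advance for them, which is what the assertion above checks.
        print(entry["profiling_name"],
              entry.get("collective_seq_id"),
              entry.get("p2p_seq_id"))
    return t
```
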
Contributor

Hmm, one more possible wrinkle, since I am clearly trying to be annoying here: we actually do use the 'collective' communicator when doing coalesced p2ps. In other words, I think it might have been intentional to bump the collective seq for coalesced p2p. Someone check me on this?

Ref: we get a special NCCL stream and NCCL comm for p2p ops, but IIRC we use the normal stream/comm for coalesced p2p, the same as we would for collectives.

Given this, what makes the most sense from a logging point of view? It will clearly confuse users if we increment the collective seq when we do p2p. On the other hand, if we are trying to do trace analysis to debug a hang, it might be important to note that this coalesced p2p happened sequentially after a collective on the same ncclComm.

@GSSBMW

GSSBMW commented Sep 23, 2024

Hi @c-p-i-o, is this PR ready to merge? Thanks!

@c-p-i-o
Contributor Author

c-p-i-o commented Sep 25, 2024

Hi @c-p-i-o, is this PR ready to merge? Thanks!

I was trying to address @wconstab's comment above

we actually do use the 'collective' communicator when doing coalesced p2ps

If the 'collective' communicator hangs while doing a coalesced p2p, we might falsely infer that the hang was due to a collective instead of a coalesced p2p. I'm thinking of introducing (another) flag that indicates whether the collective is a coalesced op or not.
I might do that in a subsequent change, though.

@GSSBMW

GSSBMW commented Oct 7, 2024

Hi @c-p-i-o,
Are there any remaining issues to be fixed, or is this ready to merge?

I was trying to address @wconstab's comment above

I don't quite follow @wconstab's comment in the context of the current PR, so I'm not sure whether it has resolved all the concerns.

Thanks!

@c-p-i-o
Contributor Author

c-p-i-o commented Oct 8, 2024

@pytorchbot merge

@pytorch-bot pytorch-bot bot added the ciflow/trunk Trigger trunk jobs on your pull request label Oct 8, 2024
@pytorchmergebot
Collaborator

Merge started

Your change will be merged once all checks pass (ETA 0-4 Hours).

Learn more about merging in the wiki.

Questions? Feedback? Please reach out to the PyTorch DevX Team

Advanced Debugging: check the merge workflow status here.

@pytorchmergebot
Collaborator

Merge failed

Reason: 1 jobs have failed, first few of them are: linux-binary-libtorch-pre-cxx11 / libtorch-cpu-shared-with-deps-pre-cxx11-test / test

Details for Dev Infra team Raised by workflow job

@c-p-i-o
Contributor Author

c-p-i-o commented Oct 8, 2024

@pytorchbot merge -i

"Unrelated failure" here: inux-binary-libtorch-pre-cxx11 / libtorch-cpu-shared-with-deps-pre-cxx11-test / test

@pytorchmergebot
Collaborator

Merge started

Your change will be merged while ignoring the following 1 checks: linux-binary-libtorch-pre-cxx11 / libtorch-cpu-shared-with-deps-pre-cxx11-test / test

Learn more about merging in the wiki.

Questions? Feedback? Please reach out to the PyTorch DevX Team

Advanced Debugging: check the merge workflow status here.

@pytorchmergebot
Collaborator

Merge failed

Reason: 1 jobs have failed, first few of them are: linux-binary-manywheel / manywheel-py3_9-cuda12_4-test / test

Details for Dev Infra team Raised by workflow job

@c-p-i-o
Contributor Author

c-p-i-o commented Oct 8, 2024

@pytorchbot merge -i

Failures are unrelated to the change: linux-binary-manywheel / manywheel-py3_9-cuda12_4-test / test.
The recursive_directory_iterator failures have nothing to do with this change.

@pytorchmergebot
Collaborator

Merge started

Your change will be merged while ignoring the following 5 checks: linux-binary-manywheel / manywheel-py3_9-cuda12_4-test / test, linux-binary-manywheel / manywheel-py3_9-cuda12_1-test / test, linux-binary-manywheel / manywheel-py3_9-cuda11_8-test / test, linux-binary-libtorch-pre-cxx11 / libtorch-cpu-shared-with-deps-pre-cxx11-test / test, trunk / linux-focal-cuda12.4-py3.10-gcc9-sm86 / build

Learn more about merging in the wiki.

Questions? Feedback? Please reach out to the PyTorch DevX Team

Advanced Debugging: check the merge workflow status here.

pytorchmergebot pushed a commit that referenced this pull request Oct 10, 2024
Summary:
Fix the sequence number in the execution trace dump so that collective/p2p ops can be matched with their corresponding wait during execution trace replay.

`ProcessGroupNCCL` has two sequence number counters, `seqCollective_` and `seqP2P_`.
https://github.com/pytorch/pytorch/blob/b18ba9419e7062acbd49bef5c388e1b1d6a170dc/torch/csrc/distributed/c10d/ProcessGroupNCCL.hpp#L1188-L1191
However, `WorkNCCL` only has one sequence number member, `seq_`. https://github.com/pytorch/pytorch/blob/b18ba9419e7062acbd49bef5c388e1b1d6a170dc/torch/csrc/distributed/c10d/ProcessGroupNCCL.hpp#L387
We need to match collectives and p2p ops with their waits separately.
facebookresearch/param@29b5a46

Depends on: #135132

Test Plan: buck2 run mode/dev-nosan kineto/libkineto/fb/integration_tests:pytorch_execution_trace_integration_test

Differential Revision:

Pull Request resolved: #134578
Approved by: https://github.com/kwen2501, https://github.com/c-p-i-o
@github-actions github-actions bot deleted the cpio/fix_seq_nums_for_coalescing branch November 29, 2024 02:10
