[c10d] Pass avoidRecordStreams into collective() function by kwen2501 · Pull Request #112195 · pytorch/pytorch · GitHub

Conversation

@kwen2501
Contributor

Even after PR #111431, the collective(...) function still reads the underscore-suffixed member avoidRecordStreams_ internally and does not respect each collective call's preference, because that member is controlled only by an environment variable.

As a fix, we pass avoidRecordStreams into the collective() function.
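The pattern described above can be sketched with a small Python model. This is not the actual ProcessGroupNCCL code (which is C++); the class `ToyProcessGroup`, its method names, and the environment variable `TOY_AVOID_RECORD_STREAMS` are all hypothetical and exist only to illustrate how a per-call preference gets lost when the inner function reads a group-wide member instead of a parameter.

```python
import os


class ToyProcessGroup:
    """Hypothetical stand-in for a process group; not the real ProcessGroupNCCL."""

    def __init__(self):
        # Group-wide default, read once from the environment.
        # The variable name is an assumption made for this sketch.
        self.avoid_record_streams_ = (
            os.environ.get("TOY_AVOID_RECORD_STREAMS", "0") == "1"
        )
        self.record_stream_calls = 0  # counts simulated record_stream calls

    def _collective_before_fix(self, tensor):
        # Before: only the member is consulted, so the caller's
        # per-call preference never reaches this point.
        if not self.avoid_record_streams_:
            self.record_stream_calls += 1
        return tensor

    def _collective_after_fix(self, tensor, avoid_record_streams):
        # After: the preference arrives as an explicit argument.
        if not avoid_record_streams:
            self.record_stream_calls += 1
        return tensor

    def all_gather_before_fix(self, tensor, avoid_record_streams):
        # The per-call preference is silently dropped here.
        return self._collective_before_fix(tensor)

    def all_gather_after_fix(self, tensor, avoid_record_streams):
        return self._collective_after_fix(tensor, avoid_record_streams)
```

With the environment variable unset, calling `all_gather_before_fix(..., avoid_record_streams=True)` in this model still records the stream, while `all_gather_after_fix(..., avoid_record_streams=True)` does not.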

@pytorch-bot

pytorch-bot bot commented Oct 26, 2023


✅ No Failures

As of commit 4f0a652 with merge base 1b702b1 (image):
💚 Looks good so far! There are no failures yet. 💚

This comment was automatically generated by Dr. CI and updates every 15 minutes.

Collaborator

@awgu awgu left a comment

This makes sense to me!

@awgu
Collaborator

awgu commented Oct 27, 2023

It looks like this broke correctness from the unit tests :/

@awgu awgu self-requested a review October 27, 2023 03:18
The sharded param is created on the `pre_unshard` stream, then used by the
`unshard` stream, then by PG-NCCL's stream. A `record_stream` call must be added
between `pre_unshard` and `unshard` if PG-NCCL stops doing this on our behalf.
@awgu awgu added the ciflow/trunk Trigger trunk jobs on your pull request label Oct 28, 2023
Collaborator

@awgu awgu left a comment

We discussed the FSDP change offline. FSDP was missing a record_stream call for correctness and was relying on ProcessGroupNCCL's own record_stream call, which is not technically part of the API's contract.

In other words, this change to ProcessGroupNCCL can lead to silent correctness failures, but it would only be for existing code that depended on unspecified behavior.
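As a rough illustration of the missing call, here is a minimal sketch using the public `torch.cuda.Stream` and `Tensor.record_stream` APIs. It is not the actual FSDP change: the function name `unshard_with_record_stream` and the tensor contents are made up for this example, and only the stream names `pre_unshard` and `unshard` are borrowed from the commit message. The function returns `None` when PyTorch or CUDA is unavailable.

```python
try:
    import torch
except ImportError:  # lets the sketch load where PyTorch is not installed
    torch = None


def unshard_with_record_stream():
    """Allocate a tensor on one stream and use it safely on another."""
    if torch is None or not torch.cuda.is_available():
        return None

    pre_unshard = torch.cuda.Stream()
    unshard = torch.cuda.Stream()

    # The tensor's memory is owned by the stream it was allocated on.
    with torch.cuda.stream(pre_unshard):
        shard = torch.ones(1024, device="cuda")

    # Order the consumer stream after the producer stream.
    unshard.wait_stream(pre_unshard)

    with torch.cuda.stream(unshard):
        # Tell the caching allocator that `unshard` also uses this memory,
        # so it is not reused until the work queued on `unshard` finishes.
        shard.record_stream(unshard)
        doubled = shard * 2

    del shard  # safe: the allocator defers reuse because of record_stream
    unshard.synchronize()
    return doubled.sum().item()
```

Without the `record_stream` call, freeing `shard` would let the allocator hand its memory to a later allocation on `pre_unshard` while work on `unshard` might still be reading it, which is the kind of silent correctness failure described above.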

@awgu
Collaborator

awgu commented Oct 28, 2023

@pytorchbot merge

@pytorchmergebot
Collaborator

Merge started

Your change will be merged once all checks pass (ETA 0-4 Hours).


xuhancn pushed a commit to xuhancn/pytorch that referenced this pull request Nov 7, 2023
…2195)

Even after PR pytorch#111431, the `collective(...)` function still reads the underscore-suffixed member `avoidRecordStreams_` internally and does not respect each collective call's preference, because that member is controlled only by an environment variable.

As a fix, we pass `avoidRecordStreams` into the collective() function.

Pull Request resolved: pytorch#112195
Approved by: https://github.com/awgu
Skylion007 pushed a commit to Skylion007/pytorch that referenced this pull request Nov 14, 2023
@github-actions github-actions bot deleted the fix_avoid_record_stream branch May 2, 2025 02:14

Labels

ciflow/trunk Trigger trunk jobs on your pull request Merged release notes: distributed (c10d) release notes category


3 participants