[FSDP2] Fix issue with set_reduce_scatter_divide_factor errors and MixedPrecisionPolicy by mori360 · Pull Request #155964 · pytorch/pytorch · GitHub

@mori360 (Contributor) commented Jun 13, 2025

pytorch-bot bot commented Jun 13, 2025

🔗 Helpful Links

🧪 See artifacts and rendered test results at hud.pytorch.org/pr/155964

Note: Links to docs will display an error until the docs builds have been completed.

✅ No Failures

As of commit ce6d607 with merge base 9620994 (image):
💚 Looks good so far! There are no failures yet. 💚

This comment was automatically generated by Dr. CI and updates every 15 minutes.

@pytorch-bot pytorch-bot bot added the ciflow/inductor, oncall: distributed, and release notes: distributed (fsdp) labels Jun 13, 2025
@mori360 mori360 marked this pull request as ready for review June 16, 2025 17:03
@mori360 mori360 requested review from weifengpy and xunnanxu June 16, 2025 17:03
# uses NCCL's PreMulSum, which only allows data type half, float,
# or double. Set reduce_dtype as orig_dtype so that it won't be cast by
# mp_policy.
self._reduce_dtype = self._orig_dtype
Contributor

Doesn't this break the mixed precision contract? Is it silently ignoring reduce_dtype from the mp policy?

Contributor Author

Yeah, the dtype cast by mp_policy cannot run under PreMulSum if it is not a float type.
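For readers following this thread, here is a minimal sketch of what a pre-multiply-sum reduce-scatter computes. It is a pure-Python simulation and makes no NCCL or PyTorch calls. The function name `premul_sum_reduce_scatter` and the sample gradients are hypothetical, chosen only for illustration. The assumption is that a divide factor `d` is applied by scaling each rank's input by `1/d` before summation.

```python
from typing import List


def premul_sum_reduce_scatter(
    rank_inputs: List[List[float]], factor: float
) -> List[List[float]]:
    """Simulate a reduce-scatter whose reduction is "pre-multiply, then sum".

    rank_inputs[r] is the full input held by simulated rank r. Every element
    is scaled by `factor` before being summed across ranks, and rank r then
    receives the r-th contiguous shard of the reduced result.
    """
    world_size = len(rank_inputs)
    numel = len(rank_inputs[0])
    assert all(len(x) == numel for x in rank_inputs), "inputs must match in size"
    assert numel % world_size == 0, "input must split evenly across ranks"
    shard = numel // world_size

    # Scale each contribution first, then sum across ranks.
    reduced = [sum(factor * x[i] for x in rank_inputs) for i in range(numel)]

    # Scatter: rank r keeps elements [r * shard, (r + 1) * shard).
    return [reduced[r * shard:(r + 1) * shard] for r in range(world_size)]


# Two simulated ranks and a hypothetical divide factor of 4.
grads = [[1.0, 2.0, 3.0, 4.0], [3.0, 2.0, 1.0, 0.0]]
divide_factor = 4.0
shards = premul_sum_reduce_scatter(grads, 1.0 / divide_factor)
print(shards)
```

Because the scaling happens inside the collective, the tensor passed to it has to be of a data type the reduction accepts, which is what the dtype discussion above is about.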

@mori360 mori360 marked this pull request as draft June 17, 2025 03:39
@mori360 mori360 marked this pull request as ready for review June 17, 2025 17:07
@mori360 mori360 requested a review from weifengpy June 17, 2025 17:07
@skipIfRocm # regressed in ROCm 6.4, but ROCm 6.5 fixes it
@skip_if_lt_x_gpu(2)
@requires_nccl_version((2, 10), "Need NCCL 2.10+ for bf16 collectives")
def test_set_reduce_scatter_divide_factor(self):
Contributor

modify existing unit test instead?

def test_set_reduce_scatter_divide_factor(self):

return unpackPreMulSum<at::Half, ncclHalf>(reduceOp, comm);
case ncclFloat:
return unpackPreMulSum<float, ncclFloat>(reduceOp, comm);
case ncclBfloat16:
Contributor

add ProcessGroupNCCL unit test here?

reduce_scatter(output, tensor_lists, c10d.ReduceOp.SUM)

Contributor

@kwen2501 it seems NCCL can take every data type it supports; that's why we added ncclBfloat16: https://github.com/NVIDIA/nccl/blob/72d2432094d6ae36abd6e511c3a16a2d052dbf94/src/enqueue.cc#L2467-L2480
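The effect of adding a bfloat16 case to a dtype switch can be sketched as a simple guard. This is an illustrative Python analogue, not the ProcessGroupNCCL code. The set names, the string dtype labels, and `check_premul_sum_dtype` are all hypothetical. The "before" set follows the code comment quoted earlier in this conversation (half, float, double).

```python
# Hypothetical string labels standing in for NCCL data type enums.
PREMUL_SUM_DTYPES_BEFORE = {"half", "float", "double"}
PREMUL_SUM_DTYPES_AFTER = PREMUL_SUM_DTYPES_BEFORE | {"bfloat16"}


def check_premul_sum_dtype(dtype: str, supported: set) -> str:
    """Return `dtype` if the dispatch table handles it, else raise TypeError."""
    if dtype not in supported:
        raise TypeError(
            f"PreMulSum data type must be one of {sorted(supported)}, got {dtype}"
        )
    return dtype


# With the extra case, a bfloat16 reduce dtype no longer falls through
# to the error branch.
accepted = check_premul_sum_dtype("bfloat16", PREMUL_SUM_DTYPES_AFTER)
```

Under this sketch, a mixed precision policy that reduces in bfloat16 would pass the guard without falling back to the original dtype.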

Contributor

rebase on top of #155915

Contributor

@weifengpy weifengpy left a comment

left minor comment about unit test

@weifengpy weifengpy requested a review from kwen2501 June 17, 2025 17:52
@mori360 mori360 marked this pull request as draft June 20, 2025 17:24
@mori360 mori360 added the ciflow/trunk label Jul 3, 2025
@mori360 mori360 marked this pull request as ready for review July 3, 2025 17:30
@mori360 mori360 requested a review from weifengpy July 4, 2025 00:56
@mori360 mori360 changed the title from "Fix issue with set_reduce_scatter_divide_factor errors and MixedPrecisionPolicy" to "[FSDP2] Fix issue with set_reduce_scatter_divide_factor errors and MixedPrecisionPolicy" Jul 4, 2025
@mori360 (Contributor Author) commented Jul 7, 2025

@pytorchmergebot merge

@pytorchmergebot (Collaborator)

Merge started

Your change will be merged once all checks pass (ETA 0-4 Hours).



Labels

ciflow/inductor, ciflow/trunk, Merged, oncall: distributed, release notes: distributed (fsdp)


Development

Successfully merging this pull request may close these issues.

[FSDP2] set_reduce_scatter_divide_factor errors with non-trivial MixedPrecisionPolicy

3 participants