fix: [FSDP2] reshard_after_forward=False for root model by weifengpy · Pull Request #464 · NVIDIA-NeMo/RL · GitHub

Conversation

@weifengpy
Contributor

What does this PR do ?

Hi from PyTorch FSDP2! This PR sets fully_shard(reshard_after_forward=False) to keep the memory behavior the same after the PyTorch-side change: pytorch/pytorch#154704

For the root model, reshard_after_forward=False keeps the root parameters unsharded after forward, since they are used in backward immediately. This is an A/A change (no change in behavior).
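
A minimal sketch (not the exact NeMo-RL diff) of the wrapping order this assumes: inner blocks are sharded with reshard_after_forward=True, then the root module with reshard_after_forward=False. The toy model is illustrative, an initialized process group is assumed, and on torch < 2.6 fully_shard lives under torch.distributed._composable.fsdp instead.

import torch.nn as nn
from torch.distributed.fsdp import fully_shard  # torch >= 2.6 import path


class ToyModel(nn.Module):
    def __init__(self):
        super().__init__()
        self.layers = nn.ModuleList(nn.Linear(128, 128) for _ in range(4))

    def forward(self, x):
        for layer in self.layers:
            x = layer(x)
        return x


model = ToyModel()
for block in model.layers:
    # Inner modules: reshard after forward to free unsharded parameter memory between layers.
    fully_shard(block, reshard_after_forward=True)
# Root module: keep parameters unsharded after forward, since backward uses
# them immediately and would otherwise re-all-gather them right away.
fully_shard(model, reshard_after_forward=False)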

Issues

List issues that this PR closes (syntax):

Usage

  • You can potentially add a usage example below
uv run python examples/run_grpo_math.py cluster.gpus_per_node=8

Before your PR is "Ready for review"

Pre checks:

  • Make sure you read and followed Contributor guidelines
  • Did you write any new necessary tests?
  • Did you run the unit tests and functional tests locally? Visit our Testing Guide for how to run tests
  • Did you add or update any necessary documentation? Visit our Document Development Guide for how to write, build and test the docs.

Additional Information

  • ...

@weifengpy
Contributor Author

Would love to support more from the PyTorch FSDP2 side. @gshennvm @yuki-666 @terrykong @parthchadha

@gshennvm
Contributor

gshennvm commented Jun 2, 2025

Would love to support more from the PyTorch FSDP2 side. @gshennvm @yuki-666 @terrykong @parthchadha

thanks for the contribution! Just for my understanding -- doesn't pytorch set this to False by default already?

from the pytorch docs:

The root FSDP state has its value specially set to False as a heuristic since its parameters would typically be immediately all-gathered for backward.

cc @terrykong on how we can work together for better fsdp2 integration :)

@weifengpy
Contributor Author

weifengpy commented Jun 2, 2025

doesn't pytorch set this to False by default already?

from the pytorch docs:

The root FSDP state has its value specially set to False as a heuristic since its parameters would typically be immediately all-gathered for backward.

I am about to remove that heuristic in a future PyTorch release: pytorch/pytorch#154704

For the root model, if the user sets fully_shard(reshard_after_forward=True):

  • in existing PyTorch releases, we override it to False, which is too implicit
  • in future PyTorch releases, we respect the user's reshard_after_forward=True config

I will also update the docs to recommend reshard_after_forward=False for the root model (see the sketch below).
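
A short sketch of the point above; the root module is illustrative and an initialized process group is assumed:

import torch.nn as nn
from torch.distributed.fsdp import fully_shard  # torch >= 2.6 import path

root = nn.Sequential(nn.Linear(64, 64), nn.Linear(64, 64))

# With reshard_after_forward=True on the root module:
#   * existing releases: the root heuristic silently overrides it to False
#   * after pytorch/pytorch#154704: True is respected, so the root is
#     resharded after forward and all-gathered again at the start of backward
# Passing False explicitly keeps today's memory behavior in both cases.
fully_shard(root, reshard_after_forward=False)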

gshennvm
gshennvm previously approved these changes Jun 2, 2025
Contributor

@gshennvm gshennvm left a comment


Ah, that makes sense. Thanks for the info! Approved.

We can merge once the tests are resolved.

@terrykong terrykong changed the title [FSDP2] reshard_after_forward=False for root model fix: [FSDP2] reshard_after_forward=False for root model Jun 2, 2025
@terrykong
Contributor

Hi @weifengpy. Thanks, could you rebase your commits with --signoff to pass our DCO check?

@weifengpy
Contributor Author

Hi @weifengpy. Thanks, could you rebase your commits with --signoff to pass our DCO check?

Just committed with --signoff

weifengpy added 2 commits June 2, 2025 20:32
Signed-off-by: Wei Feng <weif@meta.com>
Signed-off-by: Wei Feng <weif@meta.com>
@terrykong terrykong enabled auto-merge June 3, 2025 06:11
@terrykong terrykong added this pull request to the merge queue Jun 3, 2025
Merged via the queue into NVIDIA-NeMo:main with commit a1bf952 Jun 3, 2025
13 of 14 checks passed
YzjiaoNvd pushed a commit to YzjiaoNvd/NeMo-RL that referenced this pull request Jun 10, 2025
