
FSDP + DTensor Loss Flatlines Randomly #117471

@mvpatel2000


🐛 Describe the bug

We have been training with DTensor off torch nightly (in anticipation of 2.2), and we are very often seeing the loss flatline. We do not see this at all on the current nightly (as of 4 days ago), so at this point we are very confident there is a regression/bug in the current release candidate for 2.2 that breaks FSDP training (at least with DTensor).
Our best guess is that one of the two linked PRs fixes it:

[image]
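
For context, here is a minimal sketch of the kind of FSDP + DTensor setup with per-step loss logging that makes a flatline easy to spot. This is not the original training setup: the toy model, 1-D mesh layout, hyperparameters, and the use of FSDP's `device_mesh` argument (available on recent nightlies / the 2.2 branch) are all illustrative assumptions.

```python
import torch
import torch.nn as nn
import torch.distributed as dist
from torch.distributed.device_mesh import init_device_mesh
from torch.distributed.fsdp import FullyShardedDataParallel as FSDP

# Hypothetical repro sketch: a toy model wrapped in FSDP over a 1-D device
# mesh (DTensor-backed sharding), logging the loss every few steps so a
# flatline (loss stuck at a constant value) is easy to spot.
def main():
    dist.init_process_group("nccl")
    rank = dist.get_rank()
    torch.cuda.set_device(rank % torch.cuda.device_count())

    mesh = init_device_mesh("cuda", (dist.get_world_size(),))
    model = nn.Sequential(
        nn.Linear(1024, 4096), nn.GELU(), nn.Linear(4096, 1024)
    ).cuda()
    model = FSDP(model, device_mesh=mesh, use_orig_params=True)

    opt = torch.optim.AdamW(model.parameters(), lr=1e-4)
    for step in range(1000):
        x = torch.randn(8, 1024, device="cuda")
        loss = (model(x) - x).pow(2).mean()  # simple reconstruction objective
        opt.zero_grad()
        loss.backward()
        opt.step()
        if rank == 0 and step % 10 == 0:
            print(f"step {step}: loss {loss.item():.6f}")

    dist.destroy_process_group()

if __name__ == "__main__":
    main()
```

Run with `torchrun --nproc_per_node=<num_gpus> repro.py` on each build under comparison; on a healthy build the logged loss should keep decreasing rather than locking onto a constant value.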

Versions

Torch 2.2 branch
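
As an aside (not part of the original report), one quick way to pin down exactly which build is being tested, since a 2.2 release-candidate build and a recent nightly can be easy to confuse:

```python
import torch

# Report the installed PyTorch version string and the git commit it was
# built from, which disambiguates a 2.2 release-branch build from a nightly.
print(torch.__version__)
print(torch.version.git_version)
```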

cc @zhaojuanmao @mrshenli @rohan-varma @awgu @fegin @penguinwu @kwen2501 @wanchaol @XilunWu @tianyu-l

Labels: module: dtensor (distributed tensor), module: fsdp, triaged
