
ROCm context parallel backward lse not scaled #163958

@vexilligera


🐛 Describe the bug

Based on issue #156012 and PR #156903:

The fix patched the forward pass but did not patch the backward pass. To patch the backward pass, add

logsumexp /= 0.6931471805599453

at
https://github.com/ROCm/pytorch/blob/cfa0de7c5151cfd4d036b2b4ee6d35a37bd7a983/torch/distributed/tensor/experimental/_attention.py#L498
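For context, 0.6931471805599453 is ln 2, so the division converts a natural-log logsumexp into its base-2 equivalent (ln x / ln 2 = log2 x), presumably mirroring what the forward-pass fix in #156903 does for the ROCm kernel's convention. A minimal sketch of what such a rescale could look like, assuming a hypothetical helper name and gating on torch.version.hip to detect ROCm builds (illustrative only, not the actual patch):

```python
import torch

LN2 = 0.6931471805599453  # natural log of 2

def _rescale_lse_for_rocm_backward(logsumexp: torch.Tensor) -> torch.Tensor:
    # Hypothetical helper: on ROCm builds, convert the natural-log logsumexp
    # saved by the forward pass into base-2 form before the backward call,
    # since ln(x) / ln(2) == log2(x). On other builds, return it unchanged.
    if torch.version.hip is not None:
        return logsumexp / LN2
    return logsumexp

# Quick sanity check on the constant itself: dividing a natural-log
# logsumexp by ln(2) reproduces the base-2 logsumexp.
x = torch.randn(4, 8)
lse_nat = torch.logsumexp(x, dim=-1)
assert torch.allclose(lse_nat / LN2, torch.log2(x.exp().sum(dim=-1)))
```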

Versions

Before patching, the gradient difference is on the order of 1e-1; after patching, it is roughly 1e-7.

cc @jeffdaily @sunway513 @jithunnair-amd @pruthvistony @ROCmSupport @dllehr-amd @jataylo @hongxiayang @naromero77amd

Labels: module: rocm, triaged
