
Separate grad norm computation from `torch.nn.utils.clip_grad_norm_` #139467

@zijian-hu

Description

🚀 The feature, motivation and pitch

Gradient norm clipping requires first computing the total gradient norm across the entire model. The current design of torch.nn.utils.clip_grad_norm_ couples that norm computation with the in-place clipping, which is insufficient for cases like pipeline parallelism (PP).
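
For reference, the existing API performs both steps in a single call: it computes the total norm over all supplied parameters, rescales their gradients in place if needed, and returns the norm it computed. A minimal example of the current usage:

```python
import torch
import torch.nn as nn

model = nn.Linear(16, 4)
loss = model(torch.randn(8, 16)).pow(2).mean()
loss.backward()

# Computes the total 2-norm over all parameters' gradients and, if it exceeds
# max_norm, rescales the gradients in place; the computed norm is returned.
total_norm = torch.nn.utils.clip_grad_norm_(model.parameters(), max_norm=1.0)
```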

When using pipeline parallelism (PP), the grad norm needs to be computed on each PP stage and then reduced across all PP stages. Separating the grad norm computation from torch.nn.utils.clip_grad_norm_ would allow developers to properly reduce the grad norm before clipping, as in the sketch below.
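
A rough sketch of what this separation could enable, assuming a working PP setup. Here stage_model is this rank's pipeline stage, pp_group is the PP process group, and the helper functions (stage_grad_norm, clip_with_precomputed_norm_) are illustrative, not an existing API:

```python
import torch
import torch.distributed as dist

def stage_grad_norm(parameters, norm_type: float = 2.0) -> torch.Tensor:
    # Local gradient norm over this PP stage's parameters only.
    grads = [p.grad for p in parameters if p.grad is not None]
    if not grads:
        return torch.tensor(0.0)
    return torch.linalg.vector_norm(
        torch.stack([torch.linalg.vector_norm(g, norm_type) for g in grads]),
        norm_type,
    )

def clip_with_precomputed_norm_(parameters, max_norm: float, total_norm: torch.Tensor) -> None:
    # Rescale gradients in place given an externally computed total norm.
    clip_coef = max_norm / (total_norm + 1e-6)
    if clip_coef < 1.0:
        for p in parameters:
            if p.grad is not None:
                p.grad.mul_(clip_coef)

# Per PP stage: compute the local norm, reduce across stages, then clip.
params = list(stage_model.parameters())
local_norm = stage_grad_norm(params, norm_type=2.0)
total_sq = local_norm.pow(2)
dist.all_reduce(total_sq, op=dist.ReduceOp.SUM, group=pp_group)  # sum of squared stage norms
total_norm = total_sq.sqrt()
clip_with_precomputed_norm_(params, max_norm=1.0, total_norm=total_norm)
```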

Alternatives

No response

Additional context

See pytorch/torchtitan#596 and pytorch/torchtitan#649 for more context.

cc @albanD @mruberry @jbschlosser @walterddr @mikaylagawarecki

Metadata

Labels

module: nn (Related to torch.nn), triaged (This issue has been looked at by a team member, and triaged and prioritized into an appropriate module)
