Dynamo should prune non-live captured variables · Issue #127350 · pytorch/pytorch · GitHub

@zou3519

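import torch

@torch.compile(backend="eager", fullgraph=True)
def f(x):
    y = x.clone()
    # g closes over y, so Dynamo treats y as a captured variable.
    def g():
        return y + 1
    # Only the result of g() leaves f; neither g nor y escapes.
    return g()

out = f(torch.randn(4))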

Running the above with TORCH_LOGS=+graph_code gives the following graph:

def forward(self, L_x_: "f32[4]"):
    l_x_ = L_x_
    y: "f32[4]" = l_x_.clone();  l_x_ = None
    add: "f32[4]" = y + 1
    return (y, add)

Dynamo doesn't need to return y, because y is not accessible from outside the compiled function. It returns y anyway for the following reason:

  • Every time Dynamo captures a Tensor, it records a new side effect.
  • If a Tensor participates in a live side effect, Dynamo must return it as an output of the graph.
  • The side effect for a captured Tensor never gets pruned, so Dynamo returns that Tensor as an output of the graph.
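
To make the reasoning above concrete, here is a toy model in plain Python. It is only a sketch of the idea, not Dynamo's implementation: `SideEffect`, `graph_outputs`, `escaping_targets`, and the name `cell_y` are all hypothetical and exist only for illustration. It assumes a side effect can be dropped when the object it was stored into is not reachable after the function returns.

```python
from dataclasses import dataclass


@dataclass(frozen=True)
class SideEffect:
    """Toy record: `tensor` was stored into the holder named `target`."""
    tensor: str  # name of the graph value involved
    target: str  # name of the Python object (e.g. a closure cell) it was stored into


def graph_outputs(returned, side_effects, escaping_targets, prune):
    """Return the ordered list of graph outputs for a toy trace.

    returned: values the traced function returns explicitly.
    side_effects: every SideEffect recorded during tracing.
    escaping_targets: holders still reachable after the function returns.
    prune: if True, drop side effects whose target does not escape.
    """
    live = [
        se for se in side_effects
        if not prune or se.target in escaping_targets
    ]
    outputs = []
    # Tensors tied to a live side effect must be materialized as outputs...
    for se in live:
        if se.tensor not in outputs:
            outputs.append(se.tensor)
    # ...followed by the values the function actually returns.
    for value in returned:
        if value not in outputs:
            outputs.append(value)
    return outputs


# The example from this issue: g captures y, but g's cell never leaves f.
effects = [SideEffect(tensor="y", target="cell_y")]
without_pruning = graph_outputs(["add"], effects, escaping_targets=set(), prune=False)
with_pruning = graph_outputs(["add"], effects, escaping_targets=set(), prune=True)
print(without_pruning)  # ['y', 'add']
print(with_pruning)     # ['add']
```

In this model the unpruned result mirrors the `return (y, add)` in the graph above, while pruning leaves only `add`. If f returned g itself, the cell would escape and y would still have to be an output, which is why the pruning has to be based on liveness rather than applied unconditionally.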

This is potentially a performance issue for models that capture many variables.

cc @ezyang @msaroufim @bdhirsh @anijain2305 @chauhang @voznesenskym @penguinwu @EikanWang @jgong5 @Guobing-Chen @XiaobingSuper @zhuhaozhe @blzheng @wenzhe-nrv @jiayisunx @chenyang78 @kadeng

Labels

module: dynamo, oncall: pt2, triaged
