[fx] split_module fails with no-op graph and keep_original_order=True · Issue #140014 · pytorch/pytorch

@kshitij12345

Description


🐛 Describe the bug

import torch
from torch.fx.experimental.proxy_tensor import make_fx
from torch.fx.passes.split_module import split_module

def fn(x):
    return (x,)

g = make_fx(fn, tracing_mode="fake")(torch.randn(3, 3))

g.print_readable()

# `keep_original_order=False` works
# split_module(g, None, split_callback=lambda _ : 0, keep_original_order=False)

# This fails
split_module(g, None, split_callback=lambda _ : 0, keep_original_order=True)

Error

Traceback (most recent call last):
  File "/home/kkalambarkar/lightning-thunder/scratchpad/test_split_module.py", line 13, in <module>
    split_module(g, None, split_callback=lambda _ : 0, keep_original_order=True)
  File "/home/kkalambarkar/git/pytorch/torch/fx/passes/split_module.py", line 607, in split_module
    torch.fx.graph.map_arg(node.args[0], lambda n: base_mod_env[n.name])
  File "/home/kkalambarkar/git/pytorch/torch/fx/node.py", line 880, in map_arg
    return map_aggregate(a, lambda x: fn(x) if isinstance(x, Node) else x)
  File "/home/kkalambarkar/git/pytorch/torch/fx/node.py", line 889, in map_aggregate
    t = tuple([map_aggregate(elem, fn) for elem in a])
  File "/home/kkalambarkar/git/pytorch/torch/fx/node.py", line 889, in <listcomp>
    t = tuple([map_aggregate(elem, fn) for elem in a])
  File "/home/kkalambarkar/git/pytorch/torch/fx/node.py", line 906, in map_aggregate
    return fn(a)
  File "/home/kkalambarkar/git/pytorch/torch/fx/node.py", line 880, in <lambda>
    return map_aggregate(a, lambda x: fn(x) if isinstance(x, Node) else x)
  File "/home/kkalambarkar/git/pytorch/torch/fx/passes/split_module.py", line 607, in <lambda>
    torch.fx.graph.map_arg(node.args[0], lambda n: base_mod_env[n.name])
KeyError: 'x_1'
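For intuition, the failure mode can be modeled without torch: with `keep_original_order=True`, `base_mod_env` is only populated while emitting partition outputs, so a graph that produces zero partitions leaves it empty and the final output remap fails. A minimal plain-Python sketch (the simplified `map_arg` and the node name `x_1` only mimic the real code; this is not the actual torch.fx implementation):

```python
# base_mod_env maps node names in the original graph to nodes in the
# rebuilt base graph. No partitions were created, so nothing was recorded.
base_mod_env = {}

def map_arg(arg, fn):
    # Simplified stand-in for torch.fx.graph.map_arg: recurse into
    # tuples and apply fn to the leaves.
    if isinstance(arg, tuple):
        return tuple(map_arg(a, fn) for a in arg)
    return fn(arg)

output_args = ("x_1",)  # the output node references placeholder x_1
try:
    map_arg(output_args, lambda name: base_mod_env[name])
except KeyError as e:
    print("KeyError:", e)  # mirrors the reported traceback
```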

The following patch seems to do the trick (does this seem like a good fix?)

diff --git a/torch/fx/passes/split_module.py b/torch/fx/passes/split_module.py
index 0495a9520f6..2169300a922 100644
--- a/torch/fx/passes/split_module.py
+++ b/torch/fx/passes/split_module.py
@@ -601,6 +601,12 @@ def split_module(
         elif num_outputs == 1:
             base_mod_env[next(iter(partition.outputs))] = output_val
 
+    if keep_original_order and not base_mod_env:
+        for node in m.graph.nodes:
+            base_mod_env, base_mod_attrs = construct_graph(
+                node, base_mod_env, base_mod_attrs
+            )
+
     for node in m.graph.nodes:
         if node.op == "output":
             base_mod_graph.output(
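To illustrate the intent of the patch, here is a plain-Python sketch (a hypothetical stand-in for `split_module`'s bookkeeping, not the real `construct_graph`): when no partition has populated `base_mod_env`, walking the original nodes and recording placeholders into it lets the later `base_mod_env[n.name]` lookup succeed.

```python
# Nodes are modeled as dicts; construct_graph records placeholders into
# base_mod_env the way the real helper copies them into the base graph.
def construct_graph(node, env, attrs):
    if node["op"] == "placeholder":
        env[node["name"]] = node  # lookup of this name can now succeed
    return env, attrs

nodes = [{"op": "placeholder", "name": "x_1"}, {"op": "output", "name": "output"}]
base_mod_env, base_mod_attrs = {}, {}

# Mirrors the patched branch: only runs when no partition filled the env.
if not base_mod_env:
    for node in nodes:
        base_mod_env, base_mod_attrs = construct_graph(
            node, base_mod_env, base_mod_attrs
        )

assert "x_1" in base_mod_env  # the output remap no longer raises KeyError
```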

Versions

main

cc @ezyang @SherlockNoMad @EikanWang @jgong5 @wenzhe-nrv

Labels

fx, module: fx.passes (Optimization passes written in FX)
