Closed
Labels: fx, module: fx.passes
Description
🐛 Describe the bug
```python
import torch
from torch.fx.experimental.proxy_tensor import make_fx
from torch.fx.passes.split_module import split_module

def fn(x):
    return (x,)

g = make_fx(fn, tracing_mode="fake")(torch.randn(3, 3))
g.print_readable()

# `keep_original_order=False` works
# split_module(g, None, split_callback=lambda _ : 0, keep_original_order=False)

# This fails
split_module(g, None, split_callback=lambda _ : 0, keep_original_order=True)
```
Error

```
Traceback (most recent call last):
  File "/home/kkalambarkar/lightning-thunder/scratchpad/test_split_module.py", line 13, in <module>
    split_module(g, None, split_callback=lambda _ : 0, keep_original_order=True)
  File "/home/kkalambarkar/git/pytorch/torch/fx/passes/split_module.py", line 607, in split_module
    torch.fx.graph.map_arg(node.args[0], lambda n: base_mod_env[n.name])
  File "/home/kkalambarkar/git/pytorch/torch/fx/node.py", line 880, in map_arg
    return map_aggregate(a, lambda x: fn(x) if isinstance(x, Node) else x)
  File "/home/kkalambarkar/git/pytorch/torch/fx/node.py", line 889, in map_aggregate
    t = tuple([map_aggregate(elem, fn) for elem in a])
  File "/home/kkalambarkar/git/pytorch/torch/fx/node.py", line 889, in <listcomp>
    t = tuple([map_aggregate(elem, fn) for elem in a])
  File "/home/kkalambarkar/git/pytorch/torch/fx/node.py", line 906, in map_aggregate
    return fn(a)
  File "/home/kkalambarkar/git/pytorch/torch/fx/node.py", line 880, in <lambda>
    return map_aggregate(a, lambda x: fn(x) if isinstance(x, Node) else x)
  File "/home/kkalambarkar/git/pytorch/torch/fx/passes/split_module.py", line 607, in <lambda>
    torch.fx.graph.map_arg(node.args[0], lambda n: base_mod_env[n.name])
KeyError: 'x_1'
```

As far as I can tell, the traced graph here contains only a placeholder and the output node, so no partitions are created and `base_mod_env` never gets populated with `x_1`; the final output remap then fails with the `KeyError` above. The following patch seems to do the trick (does this seem like a good fix?):
```diff
diff --git a/torch/fx/passes/split_module.py b/torch/fx/passes/split_module.py
index 0495a9520f6..2169300a922 100644
--- a/torch/fx/passes/split_module.py
+++ b/torch/fx/passes/split_module.py
@@ -601,6 +601,12 @@ def split_module(
         elif num_outputs == 1:
             base_mod_env[next(iter(partition.outputs))] = output_val
 
+    if keep_original_order and not base_mod_env:
+        for node in m.graph.nodes:
+            base_mod_env, base_mod_attrs = construct_graph(
+                node, base_mod_env, base_mod_attrs
+            )
+
     for node in m.graph.nodes:
         if node.op == "output":
             base_mod_graph.output(
```
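As a sanity check, something like the following should pass with the patch applied. This is just a minimal sketch built from the repro above; the `split`/`inp` names and the assertion are mine, not an existing PyTorch test:

```python
# Hypothetical regression check for the patch above: the repro
# plus an assertion that the split module matches the original.
import torch
from torch.fx.experimental.proxy_tensor import make_fx
from torch.fx.passes.split_module import split_module

def fn(x):
    # Pure passthrough: the graph is a placeholder feeding the output,
    # so split_module creates no partitions at all.
    return (x,)

g = make_fx(fn, tracing_mode="fake")(torch.randn(3, 3))

# Previously raised KeyError: 'x_1'; should succeed with the patch.
split = split_module(g, None, split_callback=lambda _ : 0, keep_original_order=True)

# The split module should behave identically to the original graph.
inp = torch.randn(3, 3)
assert torch.equal(split(inp)[0], g(inp)[0])
```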
Versions

main