Note: This is a design doc for HOP support in the ONNX exporter.
In PyTorch export IR, control flow constructs and special ops such as cond, scan, and wrap_with_autocast are represented as higher-order ops (HOPs), which take functions (represented as local GraphModules) as inputs.
Because the functions called by HOPs are pure and do not close over values from outer naming scopes, they differ from typical ONNX subgraphs (which can reference outer-scoped values). Since the arguments to an HOP can be constructed without any information from the outer scope, it is easier and more straightforward to represent these subgraphs as ONNX functions than as ONNX graphs.
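To illustrate the distinction, here is a minimal sketch using onnx.helper (the domain and names are made up for illustration): a model-local ONNX function must declare every value it uses as a formal input, whereas a GraphProto used as a subgraph could read an outer value without declaring it.

```python
import onnx
from onnx import helper

# A model-local function: every value it uses must be an explicit input,
# mirroring how export-IR GraphModules receive all captured values as arguments.
mul_fn = helper.make_function(
    "pkg.torch.local",                                    # hypothetical function domain
    "true_graph_0",                                       # function name
    ["x", "weight"],                                      # inputs: everything is explicit
    ["mul"],                                              # outputs
    [helper.make_node("Mul", ["x", "weight"], ["mul"])],  # function body
    [helper.make_opsetid("", 18)],                        # opset imports
)

# By contrast, a GraphProto used as a subgraph (e.g. an If branch) may read
# "weight" from the enclosing graph without listing it as a formal input.
```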
Nested initializers
The GraphModules in the export IR are pure: all initializers are provided as inputs, so there are no nested initializers.
Nested subgraphs
The GraphModules can be nested: a local GraphModule can in turn contain an HOP that takes another GraphModule, local to the submodule. These nested GraphModules have unique names within their own Python naming scopes, but the names can collide once the modules are moved into a single scope. This is important because there is only one naming scope for ONNX model-local functions. If we translate all GraphModules in the exported program and place them in that single scope, we need to ensure names do not collide and can be referenced correctly.
As a concrete example, consider the following model, where a branch of the cond operator is in turn calling another cond operator:
```python
class Submodule(torch.nn.Module):
    def __init__(self):
        super().__init__()
        # Nested weight
        self.weight = torch.nn.Parameter(torch.tensor([100.0]))

    def forward(self, x):
        def true_fn(x):
            return x * self.weight

        def false_fn(x):
            return x / self.weight

        y = torch.cond(x.sum() <= 0, true_fn, false_fn, [x])
        return y


class CondModel(torch.nn.Module):
    def __init__(self):
        super().__init__()
        self.submodule = Submodule()
        self.weight = torch.nn.Parameter(torch.tensor([42.0]))

    def forward(self, x):
        def true_fn(x):
            return self.submodule(x)

        def false_fn(x):
            return x - self.weight

        y = torch.cond(x.sum() > 0, true_fn, false_fn, [x])
        return y
```

The exported program looks like this:
```python
ExportedProgram:
    class GraphModule(torch.nn.Module):
        def forward(self, p_weight: "f32[1]", p_submodule_weight: "f32[1]", x: "i64[2]"):
            sum_1: "i64[]" = torch.ops.aten.sum.default(x)
            gt: "b8[]" = torch.ops.aten.gt.Scalar(sum_1, 0); sum_1 = None
            true_graph_0 = self.true_graph_0
            false_graph_0 = self.false_graph_0
            cond = torch.ops.higher_order.cond(gt, true_graph_0, false_graph_0, [x, p_submodule_weight, p_weight]); gt = true_graph_0 = false_graph_0 = x = p_submodule_weight = p_weight = None
            getitem: "f32[2]" = cond[0]; cond = None
            return (getitem,)

        class true_graph_0(torch.nn.Module):
            def forward(self, x: "i64[2]", p_submodule_weight: "f32[1]", p_weight: "f32[1]"):
                # File: <eval_with_key>.50:6 in forward, code: sum_1 = l_args_3_0__1.sum()
                sum_1: "i64[]" = torch.ops.aten.sum.default(x)
                # File: <eval_with_key>.50:7 in forward, code: le = sum_1.le(0); sum_1 = None
                le: "b8[]" = torch.ops.aten.le.Scalar(sum_1, 0); sum_1 = None
                true_graph_0 = self.true_graph_0
                false_graph_0 = self.false_graph_0
                cond = torch.ops.higher_order.cond(le, true_graph_0, false_graph_0, [p_submodule_weight, x]); le = true_graph_0 = false_graph_0 = p_submodule_weight = x = None
                getitem: "f32[2]" = cond[0]; cond = None
                return (getitem,)

            class true_graph_0(torch.nn.Module):
                def forward(self, p_submodule_weight: "f32[1]", x: "i64[2]"):
                    mul: "f32[2]" = torch.ops.aten.mul.Tensor(x, p_submodule_weight); x = p_submodule_weight = None
                    return (mul,)

            class false_graph_0(torch.nn.Module):
                def forward(self, p_submodule_weight: "f32[1]", x: "i64[2]"):
                    div: "f32[2]" = torch.ops.aten.div.Tensor(x, p_submodule_weight); x = p_submodule_weight = None
                    return (div,)

        class false_graph_0(torch.nn.Module):
            def forward(self, x: "i64[2]", p_submodule_weight: "f32[1]", p_weight: "f32[1]"):
                sub: "f32[2]" = torch.ops.aten.sub.Tensor(x, p_weight); x = p_weight = None
                return (sub,)
```

The root graph defines true_graph_0, which defines another true_graph_0 in its own naming scope. Using graph_module.named_modules(), we get a list of modules, including the nested ones: ["", true_graph_0, true_graph_0.true_graph_0, true_graph_0.false_graph_0, false_graph_0].
Note that the names provided by named_modules() are already scoped. We can recover the local object name and the naming scope it lives in by splitting the string on its last dot.
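For instance, str.rpartition handles both nested and top-level names; a top-level module falls into the root scope "" (names taken from the printout above):

```python
# Scoped names as produced by named_modules() in the example above.
nested = "true_graph_0.true_graph_0"
top_level = "false_graph_0"

scope, _, local_name = nested.rpartition(".")
assert (scope, local_name) == ("true_graph_0", "true_graph_0")

# When there is no dot, the scope is the root scope "".
scope, _, local_name = top_level.rpartition(".")
assert (scope, local_name) == ("", "false_graph_0")
```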
Each of these graphs will have a value naming scope that has unique names for its immediate subgraphs and values:
```
{
    "": true_graph_0, false_graph_0, <…values>
    "true_graph_0": true_graph_0, false_graph_0, <…values>
    "true_graph_0.true_graph_0": <…values>
    "true_graph_0.false_graph_0": <…values>
    "false_graph_0": <…values>
}
```

We can thus construct an ONNX function for each of the non-root graphs in reverse order, so that the innermost graph is constructed and made available in the naming scope of its outer graph before the outer graph is constructed.
This motivates the following data structure for storing scoped graphs: `dict[<scope name>, dict[<graph name in scope>, <ir.Function>]]`.
The way we fill in the data structure is the following:

```python
from collections import defaultdict

scoped_graphs = defaultdict(dict)  # scope name -> {local graph name -> ONNX function}

# named_modules() yields parents before children, so iterating in reverse
# translates the innermost graphs first.
for module_name, module in reversed(list(ep.graph_module.named_modules())):
    if module_name == "":
        break  # the root module is handled separately below
    onnx_function = translate(module, scoped_graphs[module_name])
    # rpartition splits on the last dot; top-level modules fall into the root scope "".
    parent_scope, _, local_module_name = module_name.rpartition(".")
    scoped_graphs[parent_scope][local_module_name] = onnx_function

# Translate the top graph
top_graph = translate_top_graph(ep.graph, scoped_graphs[""])

# Finally construct the model and collect all ONNX functions into the model
…
```

Within translate, all values are created either as inputs to the module or by nodes in the module. Any get_attr node obtains the corresponding ONNX function from the provided scoped_graphs entry. From the ONNX function it is possible to create a single node that calls the function. Since we know the number of outputs the function has, we can also handle multi-output cases (assuming the output is not variadic).
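As an illustration of the call-site construction (not the exporter's actual helper; the function name is hypothetical), a node that invokes a model-local function is just a regular node whose op_type and domain match the function. A minimal sketch at the proto level with onnx.helper:

```python
import onnx
from onnx import helper


def build_function_call_node(
    func: onnx.FunctionProto, input_names: list[str], output_names: list[str]
) -> onnx.NodeProto:
    """Create a node that calls a model-local ONNX function.

    The number of outputs is fixed and known from func.output, which is what
    lets the translator handle multi-output (but non-variadic) cases.
    """
    assert len(output_names) == len(func.output)
    return helper.make_node(func.name, input_names, output_names, domain=func.domain)
```

At the ONNX level a function call is not a special construct; only the (domain, name) pair distinguishes it, which is why name collisions across scopes must be resolved before the functions are collected into the model.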
Implementing specific operators
Using cond as an example, the dispatcher dispatches to an implementation like:

```python
@traced
def cond_impl(cond: ir.Value, true_func: ir.Function, false_func: ir.Function, inputs: list[ir.Value]):
    # build_node is a helper that creates a node that calls the given function
    return op.If(
        cond,
        then_graph=ir.Graph(…, [build_node(true_func, inputs)]),
        else_graph=ir.Graph(…, [build_node(false_func, inputs)]),
        num_outputs=len(true_func.outputs),
    )
```

Similarly, autocast can be implemented as:
```python
@traced
def autocast_impl(device, _arg, _arg2, _arg3, func: ir.Function, *args):
    # Add custom casting logic for the device, potentially modifying the graph in func
    return call_func_as_onnxscript_func(func, *args)
```

This shows that the mechanism for supporting various HOPs is general: whenever we have a get_attr node as an input, we can assume it has already been translated into an ONNX IR function.
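For concreteness, here is a proto-level sketch (not the exporter's actual code; the names and the local-function domain are hypothetical) of what the lowered cond looks like: each If branch is a single-node graph that calls the already-translated local function, and the branch graphs may reference values such as x, w, and cond from the enclosing graph.

```python
import onnx
from onnx import TensorProto, helper

LOCAL_DOMAIN = "pkg.torch.local"  # hypothetical domain for model-local functions

# Each branch is a one-node graph calling a translated local function.
# If-branch subgraphs may reference outer-scope values ("x", "w") directly.
then_branch = helper.make_graph(
    nodes=[helper.make_node("true_graph_0", ["x", "w"], ["y_then"], domain=LOCAL_DOMAIN)],
    name="then_branch",
    inputs=[],
    outputs=[helper.make_tensor_value_info("y_then", TensorProto.FLOAT, [2])],
)
else_branch = helper.make_graph(
    nodes=[helper.make_node("false_graph_0", ["x", "w"], ["y_else"], domain=LOCAL_DOMAIN)],
    name="else_branch",
    inputs=[],
    outputs=[helper.make_tensor_value_info("y_else", TensorProto.FLOAT, [2])],
)

# The If node selects between the two branches at runtime.
if_node = helper.make_node(
    "If", inputs=["cond"], outputs=["y"], then_branch=then_branch, else_branch=else_branch
)
```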