
[ONNX] How exporter handles higher order ops (HOP) #140995

@justinchuby

Description


Note

This is a design doc for HOP support in the ONNX exporter.

In the PyTorch export IR, control flow and special ops like cond, scan, and wrap_with_autocast are represented as higher order ops (HOPs), which take functions (represented as local GraphModules) as inputs.

Because the functions called by HOPs are pure and do not close over values from outer naming scopes, they differ from typical ONNX subgraphs (which can reference outer-scope values). Since the arguments to an HOP can be constructed without any information from the outer scope, it is easier and more straightforward to represent the subgraphs as ONNX functions than as ONNX graphs.
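To make the contrast concrete, here is a minimal sketch using plain onnx.helper (not the exporter's actual code; the domain, names, and opset version are made up for illustration) of a HOP body expressed as an ONNX function, where every value it consumes must be declared as a formal input instead of being captured from an enclosing graph:

import onnx
from onnx import helper

# The body of a HOP branch as a FunctionProto: everything it reads ("x" and
# "weight") is a formal input, so the function is independent of any outer scope.
mul = helper.make_node("Mul", inputs=["x", "weight"], outputs=["y"])
true_fn = helper.make_function(
    "pkg.example",                   # function domain (hypothetical)
    "true_graph_0",                  # function name
    ["x", "weight"],                 # explicit inputs
    ["y"],                           # outputs
    [mul],                           # nodes
    [helper.make_opsetid("", 18)],   # opset imports
)
print(true_fn)

A GraphProto used as a then_branch/else_branch attribute, by contrast, may refer to values of the enclosing graph by name, which is exactly the coupling HOP bodies do not need.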

Nested initializers

The GraphModules in the export IR are pure: all initializers are provided as inputs, so there are no nested initializers.

Nested subgraphs

The GraphModules can be nested, meaning a local GraphModule can in turn contain an HOP that takes another GraphModule, local to the submodule. These nested GraphModules have unique names in their own Python naming scopes, but the names can conflict once they are moved into the same scope. This matters because there is only one naming scope for ONNX model-local functions. If we translate all GraphModules in the exported program and lift them into the same scope, we need to ensure the names do not collide and can be referenced correctly.

As a concrete example, consider the following model, where a branch of the cond operator in turn calls another cond operator:

class Submodule(torch.nn.Module):
    def __init__(self):
        super().__init__()
        # Nested weight
        self.weight = torch.nn.Parameter(torch.tensor([100.0]))

    def forward(self, x):
        def true_fn(x):
            return x * self.weight

        def false_fn(x):
            return x / self.weight

        y = torch.cond(x.sum() <= 0, true_fn, false_fn, [x])
        return y

class CondModel(torch.nn.Module):
    def __init__(self):
        super().__init__()
        self.submodule = Submodule()
        self.weight = torch.nn.Parameter(torch.tensor([42.0]))

    def forward(self, x):
        def true_fn(x):
            return self.submodule(x)

        def false_fn(x):
            return x - self.weight

        y = torch.cond(x.sum() > 0, true_fn, false_fn, [x])
        return y
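For reference, the exported program shown next can be obtained with torch.export (a minimal sketch; the example input is chosen to match the i64[2] input in the printout):

import torch

ep = torch.export.export(CondModel(), (torch.tensor([1, 2]),))
print(ep)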

The exported program looks like this:

ExportedProgram:
    class GraphModule(torch.nn.Module):
        def forward(self, p_weight: "f32[1]", p_submodule_weight: "f32[1]", x: "i64[2]"):
            sum_1: "i64[]" = torch.ops.aten.sum.default(x)
            gt: "b8[]" = torch.ops.aten.gt.Scalar(sum_1, 0);  sum_1 = None
            
            true_graph_0 = self.true_graph_0
            false_graph_0 = self.false_graph_0
            cond = torch.ops.higher_order.cond(gt, true_graph_0, false_graph_0, [x, p_submodule_weight, p_weight]);  gt = true_graph_0 = false_graph_0 = x = p_submodule_weight = p_weight = None
            getitem: "f32[2]" = cond[0];  cond = None
            return (getitem,)
            
        class true_graph_0(torch.nn.Module):
            def forward(self, x: "i64[2]", p_submodule_weight: "f32[1]", p_weight: "f32[1]"):
                 # File: <eval_with_key>.50:6 in forward, code: sum_1 = l_args_3_0__1.sum()
                sum_1: "i64[]" = torch.ops.aten.sum.default(x)
                
                 # File: <eval_with_key>.50:7 in forward, code: le = sum_1.le(0);  sum_1 = None
                le: "b8[]" = torch.ops.aten.le.Scalar(sum_1, 0);  sum_1 = None
                
                true_graph_0 = self.true_graph_0
                false_graph_0 = self.false_graph_0
                cond = torch.ops.higher_order.cond(le, true_graph_0, false_graph_0, [p_submodule_weight, x]);  le = true_graph_0 = false_graph_0 = p_submodule_weight = x = None
                getitem: "f32[2]" = cond[0];  cond = None
                return (getitem,)
                
            class true_graph_0(torch.nn.Module):
                def forward(self, p_submodule_weight: "f32[1]", x: "i64[2]"):
                    mul: "f32[2]" = torch.ops.aten.mul.Tensor(x, p_submodule_weight);  x = p_submodule_weight = None
                    return (mul,)
                    
            class false_graph_0(torch.nn.Module):
                def forward(self, p_submodule_weight: "f32[1]", x: "i64[2]"):
                    div: "f32[2]" = torch.ops.aten.div.Tensor(x, p_submodule_weight);  x = p_submodule_weight = None
                    return (div,)
                    
        class false_graph_0(torch.nn.Module):
            def forward(self, x: "i64[2]", p_submodule_weight: "f32[1]", p_weight: "f32[1]"):
                sub: "f32[2]" = torch.ops.aten.sub.Tensor(x, p_weight);  x = p_weight = None
                return (sub,)

The root graph defines true_graph_0, which in turn defines another true_graph_0 in its own naming scope. Using graph_module.named_modules(), we get a list of modules, including the nested ones: ["", "true_graph_0", "true_graph_0.true_graph_0", "true_graph_0.false_graph_0", "false_graph_0"].

Note that the names provided by named_modules() are already scoped. We can recover the local object name and the naming scope it lives in by simply splitting on the last dot in the string.
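For illustration, a small snippet (assuming ep is the ExportedProgram obtained above) that lists the scoped names and splits each into its parent scope and local name:

for name, _module in ep.graph_module.named_modules():
    parent_scope, _, local_name = name.rpartition(".")
    print(f"name={name!r} scope={parent_scope!r} local={local_name!r}")

# name='' scope='' local=''                                (the root graph)
# name='true_graph_0' scope='' local='true_graph_0'
# name='true_graph_0.true_graph_0' scope='true_graph_0' local='true_graph_0'
# ...

rpartition returns an empty parent scope for top-level names, which maps them to the root naming scope.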

Each of these graphs will have a value naming scope that has unique names for its immediate subgraphs and values:

{
    "": true_graph_0, false_graph_0, <values>
    "true_graph_0": true_graph_0, false_graph_0, <values>
    "true_graph_0.true_graph_0": <values>
    "true_graph_0.false_graph_0": <values>
    "false_graph_0": <values>
}

We can thus construct an ONNX function for each of the non-root graphs in reverse order, such that the innermost graph is constructed first and made available in the naming scope of its outer graph before that outer graph is constructed.

This motivates the following data structure for storing the scoped graphs: dict[<scope name>, dict[<graph name in scope>, <ir.Function>]].

The way we fill in the data structure is the following:

from collections import defaultdict

scoped_graphs: dict[str, dict[str, ir.Function]] = defaultdict(dict)
# named_modules() yields the root ("") first, so iterate in reverse to
# translate the innermost graphs before the graphs that call them
for module_name, module in reversed(list(ep.graph_module.named_modules())):
    if module_name == "":
        break
    onnx_function = translate(module, scoped_graphs[module_name])
    # rpartition yields an empty parent scope for top-level names
    parent_scope, _, local_module_name = module_name.rpartition(".")
    scoped_graphs[parent_scope][local_module_name] = onnx_function
# Translate the top graph
top_graph = translate_top_graph(ep.graph, scoped_graphs[""])
# Finally, construct the model and collect all ONNX functions into the model
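After the loop, scoped_graphs for the example model looks like this (translated ir.Function objects elided as ...):

{
    "": {"true_graph_0": ..., "false_graph_0": ...},
    "true_graph_0": {"true_graph_0": ..., "false_graph_0": ...},
    "true_graph_0.true_graph_0": {},  # leaf graphs contain no further HOPs
    "true_graph_0.false_graph_0": {},
    "false_graph_0": {},
}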

Within translate, all values are created either as inputs to the module or by nodes in the module. Any get_attr node obtains the corresponding ONNX function from the provided scope dictionary. From the ONNX function it is possible to create a single node that calls the function, as sketched below. Since we know how many outputs the function has, we can also handle multi-output cases (assuming the outputs are not variadic).
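As an illustration, here is a rough sketch of such a build_node helper (not the exporter's actual code; it assumes the onnxscript ir.Node constructor, which takes a domain, an op type, and input values, and can allocate a fixed number of output values):

from onnxscript import ir

def build_node(func: ir.Function, inputs: list[ir.Value]) -> ir.Node:
    # A call to an ONNX function is just a node whose domain/op_type name the
    # function; the output count is known from the function signature.
    return ir.Node(
        func.domain,
        func.name,
        inputs=inputs,
        num_outputs=len(func.outputs),
    )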

Implementing specific operators

Using cond as an example, the dispatcher dispatches to an implementation like the following:

@traced
def cond_impl(cond: ir.Value, true_func: ir.Function, false_func: ir.Function, inputs: list[ir.Value]):
    # build_node is a helper that creates a node that calls the given function
    return op.If(
        cond,
        then_branch=ir.Graph(…, [build_node(true_func, inputs)]),
        else_branch=ir.Graph(…, [build_node(false_func, inputs)]),
        num_outputs=len(true_func.outputs),
    )

Similarly, autocast can be implemented as

@traced
def autocast_impl(device, _arg, _arg2, _arg3, func: ir.Function, *args):
    # Add custom casting logic for the device, potentially modifying the graph in func
    return call_func_as_onnxscript_func(func, *args)

This shows that the mechanism for supporting various HOPs is general: whenever a get_attr node appears as an input, we can assume it has already been translated into an ONNX IR function.

Metadata

Labels

module: onnx (Related to torch.onnx), triaged (This issue has been looked at by a team member, and triaged and prioritized into an appropriate module)
