❓ [Question] operator being decomposed rather than being converted when a corresponding converter exists? · Issue #2665 · pytorch/TensorRT · GitHub

@HolyWu

Description

❓ Question

Judging from the debug log below, the aten.grid_sampler_2d operator is decomposed into lower-level operators (reshape, floor, clamp, where, index, _to_copy, and others) instead of being converted directly. Isn't there a corresponding converter that should be used instead?
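The ordering question can be illustrated with a toy model. Assuming the dynamo path applies its decomposition table to the exported graph before partitioning and converter lookup (which is how the log reads: only the decomposed ops are ever offered to converters), an op listed in that table never reaches its own converter, even when one is registered. Everything below (`DECOMPOSITIONS`, `CONVERTERS`, `lower`) is hypothetical and only models that ordering; the replacement op list is illustrative, loosely based on the op names in the log.

```python
from typing import Callable, Dict, List, Tuple

# A node is modelled as (op name, output dtype). All names here are
# hypothetical stand-ins, not Torch-TensorRT's real registries.
Node = Tuple[str, str]

# Hypothetical decomposition table: op -> replacement nodes (illustrative only).
DECOMPOSITIONS: Dict[str, List[Node]] = {
    "grid_sampler_2d": [
        ("floor", "float32"),
        ("clamp", "float32"),
        ("_to_copy", "int64"),  # index arithmetic cast to an integer dtype
        ("index", "float32"),
        ("where", "float32"),
    ],
}

# Hypothetical converter registry: op -> capability check for a node.
CONVERTERS: Dict[str, Callable[[Node], bool]] = {
    "grid_sampler_2d": lambda node: True,
    "floor": lambda node: True,
    "clamp": lambda node: True,
    "index": lambda node: True,
    "where": lambda node: True,
    # Mirrors the log: the _to_copy converter rejects casts to int64.
    "_to_copy": lambda node: node[1] != "int64",
}


def lower(graph: List[Node], decompositions: Dict[str, List[Node]] = DECOMPOSITIONS):
    """Decompose first, then split nodes into converted vs. Torch fallback."""
    nodes: List[Node] = []
    for node in graph:
        # Decomposition runs before any converter lookup, so a converter
        # registered for a decomposed op is never consulted.
        nodes.extend(decompositions.get(node[0], [node]))
    converted: List[Node] = []
    fallback: List[Node] = []
    for node in nodes:
        check = CONVERTERS.get(node[0])
        if check is not None and check(node):
            converted.append(node)
        else:
            fallback.append(node)
    return converted, fallback


graph = [("grid_sampler_2d", "float32")]
print(lower(graph))                     # decomposed: the int64 cast falls back to Torch
print(lower(graph, decompositions={}))  # not decomposed: converted as a single op
```

In this model the only way the registered `grid_sampler_2d` converter gets used is for the op to be absent from the decomposition table; whether Torch-TensorRT's default table includes it in this version is not established by the log alone.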

What you have already tried

import torch
import torch.nn as nn
import torch.nn.functional as F
import torch_tensorrt


class MyModule(nn.Module):
    def __init__(self):
        super().__init__()

    def forward(self, input, grid):
        # For 4-D inputs, F.grid_sample dispatches to aten.grid_sampler_2d
        return F.grid_sample(input, grid, mode="bilinear", padding_mode="border", align_corners=True)

model = MyModule().eval().cuda()

inputs = [
    torch.randn((1, 3, 8, 8), dtype=torch.float, device="cuda"),
    torch.randn((1, 16, 16, 2), dtype=torch.float, device="cuda")
]

optimized_model = torch_tensorrt.compile(
    model,
    ir="dynamo",
    inputs=inputs,
    enabled_precisions={torch.float},
    debug=True,
    min_block_size=1,
    truncate_long_and_double=True,
    output_format="fx",
)
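For reference, one output element of `F.grid_sample` with the script's settings (`mode="bilinear"`, `padding_mode="border"`, `align_corners=True`) can be sketched in plain Python for a single channel. This is an illustration, not PyTorch's actual decomposition of `aten.grid_sampler_2d`, and the function name is hypothetical. Written out, the computation needs clamping, flooring, an integer cast to index the neighbours, and masking of out-of-range neighbours, which lines up with the clamp, floor, _to_copy, index, and where nodes counted in the log.

```python
import math
from typing import List


def grid_sample_bilinear_border(img: List[List[float]], x: float, y: float) -> float:
    """Sample one channel at normalized (x, y); x indexes width, y height.

    Assumes mode="bilinear", padding_mode="border", align_corners=True.
    """
    h, w = len(img), len(img[0])
    # align_corners=True: -1 maps to pixel 0 and +1 maps to pixel (size - 1).
    ix = (x + 1) / 2 * (w - 1)
    iy = (y + 1) / 2 * (h - 1)
    # padding_mode="border": clamp coordinates into the valid pixel range.
    ix = min(max(ix, 0.0), w - 1)
    iy = min(max(iy, 0.0), h - 1)
    # floor + integer cast gives the top-left neighbour used for indexing.
    x0, y0 = int(math.floor(ix)), int(math.floor(iy))
    x1, y1 = x0 + 1, y0 + 1
    # Bilinear weights from the fractional part of the coordinates.
    wx1, wy1 = ix - x0, iy - y0
    wx0, wy0 = 1.0 - wx1, 1.0 - wy1

    def pixel(r: int, c: int) -> float:
        # Out-of-range neighbours are masked to 0 (their weight is 0 here anyway).
        return img[r][c] if 0 <= r < h and 0 <= c < w else 0.0

    return (pixel(y0, x0) * wy0 * wx0 + pixel(y0, x1) * wy0 * wx1
            + pixel(y1, x0) * wy1 * wx0 + pixel(y1, x1) * wy1 * wx1)


img = [[0.0, 1.0], [2.0, 3.0]]
print(grid_sample_bilinear_border(img, 0.0, 0.0))    # 1.5
print(grid_sample_bilinear_border(img, -2.0, -2.0))  # 0.0
```

For the 2x2 image above, the centre (0, 0) blends all four pixels equally, while (-2, -2) lies outside the grid and is clamped to the top-left pixel by the border padding.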

Debug output:

DEBUG:torch_tensorrt.dynamo.conversion.aten_ops_converters:_to_copy converter rejected node _to_copy with dtype torch.int64
DEBUG:torch_tensorrt.dynamo.conversion.aten_ops_converters:_to_copy converter rejected node _to_copy with dtype torch.int64
DEBUG:torch_tensorrt.dynamo.conversion.aten_ops_converters:_to_copy converter rejected node _to_copy_1 with dtype torch.int64
DEBUG:torch_tensorrt.dynamo.conversion.aten_ops_converters:_to_copy converter rejected node _to_copy_1 with dtype torch.int64
DEBUG:torch_tensorrt.dynamo.conversion.aten_ops_converters:_to_copy converter rejected node _to_copy_2 with dtype torch.int64
DEBUG:torch_tensorrt.dynamo.conversion.aten_ops_converters:_to_copy converter rejected node _to_copy_2 with dtype torch.int64
DEBUG:torch_tensorrt.dynamo.conversion.aten_ops_converters:_to_copy converter rejected node _to_copy_3 with dtype torch.int64
DEBUG:torch_tensorrt.dynamo.conversion.aten_ops_converters:_to_copy converter rejected node _to_copy_3 with dtype torch.int64
DEBUG:torch_tensorrt.dynamo.conversion.aten_ops_converters:_to_copy converter rejected node _to_copy_4 with dtype torch.int64
DEBUG:torch_tensorrt.dynamo.conversion.aten_ops_converters:_to_copy converter rejected node _to_copy_4 with dtype torch.int64
DEBUG:torch_tensorrt.dynamo.conversion.aten_ops_converters:_to_copy converter rejected node _to_copy_5 with dtype torch.int64
DEBUG:torch_tensorrt.dynamo.conversion.aten_ops_converters:_to_copy converter rejected node _to_copy_5 with dtype torch.int64
DEBUG:torch_tensorrt.dynamo.conversion.aten_ops_converters:_to_copy converter rejected node _to_copy_6 with dtype torch.int64
DEBUG:torch_tensorrt.dynamo.conversion.aten_ops_converters:_to_copy converter rejected node _to_copy_6 with dtype torch.int64
DEBUG:torch_tensorrt.dynamo.conversion.aten_ops_converters:_to_copy converter rejected node _to_copy_7 with dtype torch.int64
DEBUG:torch_tensorrt.dynamo.conversion.aten_ops_converters:_to_copy converter rejected node _to_copy_7 with dtype torch.int64
DEBUG:torch_tensorrt.dynamo.partitioning._global_partitioner:
Supported Nodes:
- torch.ops.aten.reshape.default + Operator Count: 13
- torch.ops.aten.expand.default + Operator Count: 1
- torch.ops.aten.select.int + Operator Count: 2
- torch.ops.aten.mul.Tensor + Operator Count: 10
- torch.ops.aten.add.Tensor + Operator Count: 7
- torch.ops.aten.clamp.default + Operator Count: 2
- torch.ops.aten.floor.default + Operator Count: 2
- torch.ops.aten.sub.Tensor + Operator Count: 8
- torch.ops.aten.ge.Scalar + Operator Count: 8
- torch.ops.aten.lt.Scalar + Operator Count: 8
- torch.ops.aten.logical_and.default + Operator Count: 12
- torch.ops.aten.where.self + Operator Count: 12
- torch.ops.aten.index.Tensor + Operator Count: 4

DEBUG:torch_tensorrt.dynamo.partitioning._global_partitioner:
Unsupported or Excluded Nodes:
- torch.ops.aten._to_copy.default + Operator Count: 8

DEBUG:torch_tensorrt.dynamo._compiler:Detected support for 89 operators out of 97 in subgraph.
DEBUG:torch_tensorrt.dynamo.conversion.aten_ops_converters:_to_copy converter rejected node _to_copy with dtype torch.int64
DEBUG:torch_tensorrt.dynamo.conversion.aten_ops_converters:_to_copy converter rejected node _to_copy with dtype torch.int64
DEBUG:torch_tensorrt.dynamo.conversion.aten_ops_converters:_to_copy converter rejected node _to_copy_1 with dtype torch.int64
DEBUG:torch_tensorrt.dynamo.conversion.aten_ops_converters:_to_copy converter rejected node _to_copy_1 with dtype torch.int64
DEBUG:torch_tensorrt.dynamo.conversion.aten_ops_converters:_to_copy converter rejected node _to_copy_2 with dtype torch.int64
DEBUG:torch_tensorrt.dynamo.conversion.aten_ops_converters:_to_copy converter rejected node _to_copy_2 with dtype torch.int64
DEBUG:torch_tensorrt.dynamo.conversion.aten_ops_converters:_to_copy converter rejected node _to_copy_3 with dtype torch.int64
DEBUG:torch_tensorrt.dynamo.conversion.aten_ops_converters:_to_copy converter rejected node _to_copy_3 with dtype torch.int64
DEBUG:torch_tensorrt.dynamo.conversion.aten_ops_converters:_to_copy converter rejected node _to_copy_4 with dtype torch.int64
DEBUG:torch_tensorrt.dynamo.conversion.aten_ops_converters:_to_copy converter rejected node _to_copy_4 with dtype torch.int64
DEBUG:torch_tensorrt.dynamo.conversion.aten_ops_converters:_to_copy converter rejected node _to_copy_5 with dtype torch.int64
DEBUG:torch_tensorrt.dynamo.conversion.aten_ops_converters:_to_copy converter rejected node _to_copy_5 with dtype torch.int64
DEBUG:torch_tensorrt.dynamo.conversion.aten_ops_converters:_to_copy converter rejected node _to_copy_6 with dtype torch.int64
DEBUG:torch_tensorrt.dynamo.conversion.aten_ops_converters:_to_copy converter rejected node _to_copy_6 with dtype torch.int64
DEBUG:torch_tensorrt.dynamo.conversion.aten_ops_converters:_to_copy converter rejected node _to_copy_7 with dtype torch.int64
DEBUG:torch_tensorrt.dynamo.conversion.aten_ops_converters:_to_copy converter rejected node _to_copy_7 with dtype torch.int64
DEBUG:torch_tensorrt.dynamo.partitioning._adjacency_partitioner:
Number of TensorRT-Accelerated Engines Generated: 2
DEBUG:torch_tensorrt.dynamo.partitioning._adjacency_partitioner:
Supported Nodes:
- torch.ops.aten.reshape.default + Operator Count: 13
- torch.ops.aten.expand.default + Operator Count: 1
- torch.ops.aten.select.int + Operator Count: 2
- torch.ops.aten.mul.Tensor + Operator Count: 10
- torch.ops.aten.add.Tensor + Operator Count: 7
- torch.ops.aten.clamp.default + Operator Count: 2
- torch.ops.aten.floor.default + Operator Count: 2
- torch.ops.aten.sub.Tensor + Operator Count: 8
- torch.ops.aten.ge.Scalar + Operator Count: 8
- torch.ops.aten.lt.Scalar + Operator Count: 8
- torch.ops.aten.logical_and.default + Operator Count: 12
- torch.ops.aten.where.self + Operator Count: 12
- torch.ops.aten.index.Tensor + Operator Count: 4

DEBUG:torch_tensorrt.dynamo.partitioning._adjacency_partitioner:
Unsupported or Excluded Nodes:
- torch.ops.aten._to_copy.default + Operator Count: 8
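
Each block of rejection lines above names only 8 distinct nodes (_to_copy through _to_copy_7, each listed twice), and the per-op counts under "Supported Nodes" add up to 89, which together with the 8 excluded _to_copy nodes gives the 97 operators reported. A small helper for tallying logs like this, assuming the line formats shown above stay stable (the function names are hypothetical):

```python
import re
from collections import Counter
from typing import Dict

# Patterns for the two line formats seen in the log above.
REJECT_RE = re.compile(r"converter rejected node (\S+) with dtype (\S+)")
COUNT_RE = re.compile(r"^- (\S+) \+ Operator Count: (\d+)$", re.MULTILINE)


def rejected_nodes(log: str) -> Counter:
    """Count how often each node name appears in a 'rejected node' line."""
    return Counter(m.group(1) for m in REJECT_RE.finditer(log))


def op_counts(section: str) -> Dict[str, int]:
    """Parse '- <op> + Operator Count: <n>' lines into {op: n}."""
    return {op: int(n) for op, n in COUNT_RE.findall(section)}


# "Supported Nodes" and "Unsupported or Excluded Nodes" sections from the log.
SUPPORTED = """\
- torch.ops.aten.reshape.default + Operator Count: 13
- torch.ops.aten.expand.default + Operator Count: 1
- torch.ops.aten.select.int + Operator Count: 2
- torch.ops.aten.mul.Tensor + Operator Count: 10
- torch.ops.aten.add.Tensor + Operator Count: 7
- torch.ops.aten.clamp.default + Operator Count: 2
- torch.ops.aten.floor.default + Operator Count: 2
- torch.ops.aten.sub.Tensor + Operator Count: 8
- torch.ops.aten.ge.Scalar + Operator Count: 8
- torch.ops.aten.lt.Scalar + Operator Count: 8
- torch.ops.aten.logical_and.default + Operator Count: 12
- torch.ops.aten.where.self + Operator Count: 12
- torch.ops.aten.index.Tensor + Operator Count: 4
"""
UNSUPPORTED = "- torch.ops.aten._to_copy.default + Operator Count: 8\n"

supported = sum(op_counts(SUPPORTED).values())
total = supported + sum(op_counts(UNSUPPORTED).values())
print(supported, total)  # 89 97

EXCERPT = (
    "DEBUG:torch_tensorrt.dynamo.conversion.aten_ops_converters:_to_copy converter rejected node _to_copy with dtype torch.int64\n"
    "DEBUG:torch_tensorrt.dynamo.conversion.aten_ops_converters:_to_copy converter rejected node _to_copy with dtype torch.int64\n"
)
print(rejected_nodes(EXCERPT))  # Counter({'_to_copy': 2})
```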

++++++++++++++++++++++++++++++++++++++++++++++++++ Dry-Run Results for Graph ++++++++++++++++++++++++++++++++++++++++++++++++++

The graph consists of 97 Total Operators, of which 89 operators are supported, 91.75% coverage

The following ops are currently unsupported or excluded from conversion, and are listed with their op-count in the graph:
 torch.ops.aten._to_copy.default: 8

The following nodes are currently set to run in Torch:
Node: torch.ops.aten._to_copy.default, with layer location: __/_to_copy
Node: torch.ops.aten._to_copy.default, with layer location: __/_to_copy_1
Node: torch.ops.aten._to_copy.default, with layer location: __/_to_copy_2
Node: torch.ops.aten._to_copy.default, with layer location: __/_to_copy_3
Node: torch.ops.aten._to_copy.default, with layer location: __/_to_copy_4
Node: torch.ops.aten._to_copy.default, with layer location: __/_to_copy_5
Node: torch.ops.aten._to_copy.default, with layer location: __/_to_copy_6
Node: torch.ops.aten._to_copy.default, with layer location: __/_to_copy_7
Note: Some of the above nodes may be supported, but were not included in a TRT graph by the partitioner

Compiled with: CompilationSettings(precision=torch.float32, debug=True, workspace_size=0, min_block_size=1, torch_executed_ops=set(), pass_through_build_failures=False, max_aux_streams=None, version_compatible=False, optimization_level=None, use_python_runtime=False, truncate_long_and_double=True, use_fast_partitioner=True, enable_experimental_decompositions=False, device=Device(type=DeviceType.GPU, gpu_id=0), require_full_compilation=False, disable_tf32=False, sparse_weights=False, refit=False, engine_capability=<EngineCapability.DEFAULT: 0>, num_avg_timing_iters=1, dla_sram_size=1048576, dla_local_dram_size=1073741824, dla_global_dram_size=536870912, dryrun=False, hardware_compatible=False, output_format='fx')

  Graph Structure:

   Inputs: List[Tensor: (1, 3, 8, 8)@float32, Tensor: (1, 16, 16, 2)@float32]
    ...
    TRT Engine #1 - Submodule name: _run_on_acc_0
     Engine Inputs: List[Tensor: (1, 16, 16, 2)@float32]
     Number of Operators in Engine: 58
     Engine Outputs: Tuple(Tensor: (1, 3, 16, 16)@float32, Tensor: (1, 3, 16, 16)@float32, Tensor: (1, 3, 16, 16)@bool, Tensor: (1, 3, 16, 16)@float32, Tensor: (1, 3, 16, 16)@float32, Tensor: (1, 3, 16, 16)@bool, Tensor: (1, 3, 16, 16)@float32, Tensor: (1, 3, 16, 16)@float32, Tensor: (1, 3, 16, 16)@bool, Tensor: (1, 3, 16, 16)@float32, Tensor: (1, 3, 16, 16)@bool, Tensor: (1, 3, 16, 16)@float32)
    ...
    TRT Engine #2 - Submodule name: _run_on_acc_2
     Engine Inputs: List[Tensor: (1, 3, 16, 16)@bool, Tensor: (1, 3, 16, 16)@int32, Tensor: (1, 3, 16, 16)@int32, Tensor: (1, 3, 8, 8)@float32, Tensor: (1, 3, 16, 16)@float32, Tensor: (1, 3, 16, 16)@bool, Tensor: (1, 3, 16, 16)@int32, Tensor: (1, 3, 16, 16)@int32, Tensor: (1, 3, 16, 16)@float32, Tensor: (1, 3, 16, 16)@bool, Tensor: (1, 3, 16, 16)@int32, Tensor: (1, 3, 16, 16)@int32, Tensor: (1, 3, 16, 16)@float32, Tensor: (1, 3, 16, 16)@bool, Tensor: (1, 3, 16, 16)@int32, Tensor: (1, 3, 16, 16)@int32, Tensor: (1, 3, 16, 16)@float32]
     Number of Operators in Engine: 31
     Engine Outputs: Tensor: (1, 3, 16, 16)@float32
    ...
   Outputs: List[Tensor: (1, 3, 16, 16)@float32]

  ------------------------- Aggregate Stats -------------------------

   Average Number of Operators per TRT Engine: 44.5
   Most Operators in a TRT Engine: 58

  ********** Recommendations **********

   - For minimal graph segmentation, select min_block_size=58 which would generate 1 TRT engine(s)
   - For moderate graph segmentation, select min_block_size=45 which would generate 1 TRT engine(s)
   - The current level of graph segmentation is equivalent to selecting min_block_size=31 which generates 2 TRT engine(s)
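
The dry-run figures can be re-derived from the engine sizes: 89 of 97 operators gives the 91.75% coverage, (58 + 31) / 2 gives the 44.5 average, and the three recommendations are consistent with the documented meaning of `min_block_size` (contiguous blocks with fewer operators than this are left to run in Torch). A small sketch under that assumption; it says nothing about how the partitioner actually picks its suggested values, and the helper names are hypothetical:

```python
from typing import List


def coverage_percent(supported: int, total: int) -> float:
    """Share of convertible operators, rounded to two decimals as in the report."""
    return round(100 * supported / total, 2)


def average_ops(engine_sizes: List[int]) -> float:
    """Mean number of operators per TensorRT engine."""
    return sum(engine_sizes) / len(engine_sizes)


def engines_kept(block_sizes: List[int], min_block_size: int) -> List[int]:
    """Blocks with fewer than min_block_size ops are assumed to stay in Torch."""
    return [size for size in block_sizes if size >= min_block_size]


blocks = [58, 31]  # operators in TRT Engine #1 and TRT Engine #2
print(coverage_percent(89, 97))  # 91.75
print(average_ops(blocks))       # 44.5
for m in (58, 45, 31):
    print(m, len(engines_kept(blocks, m)))  # 58 -> 1, 45 -> 1, 31 -> 2
```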

Environment

  • PyTorch Version (e.g., 1.0): 2.3.0.dev20240221+cu121
  • CPU Architecture: x64
  • OS (e.g., Linux): Ubuntu 22.04 LTS
  • How you installed PyTorch (conda, pip, libtorch, source): pip
  • Build command you used (if compiling from source):
  • Are you using local sources or building from archives:
  • Python version: 3.10.12
  • CUDA version: 12.1
  • GPU models and configuration: RTX 3050
  • Any other relevant information:

Labels: question (Further information is requested)
