
max_pool2d CPU forward performance is poor #51393

@jamesr66a

Description

import torch
import torch.fx

import torchvision.models as models
rn18 = models.resnet18()
rn18.eval()
rn18.requires_grad_(False)

N, C, H, W = 10, 3, 224, 224

def rn18_bench(input, chrome_trace_filename):
    with torch.autograd.profiler.profile(record_shapes=True) as prof:
        rn18(input)  # run on the function argument, not the global x
    prof.export_chrome_trace(chrome_trace_filename)

x = torch.randn(N, C, H, W, requires_grad=False)
rn18_bench(x, 'maxpool_nchw.json')

x_nhwc = x.contiguous(memory_format=torch.channels_last)
rn18_bench(x_nhwc, 'maxpool_nhwc.json')

# Roofline analysis of the maxpool op


def maxpool_bench(input, name):
    import time

    # warmup iteration. don't care if things get put in cache because we're not
    # even close to hitting bandwidth/cache bound
    rn18.maxpool(input)

    NITER = 100
    s = time.time()
    for _ in range(NITER):
        out = rn18.maxpool(input)
    e = time.time()

    time_per_iter_sec = (e - s) / NITER
    bytes_in = input.numel() * input.element_size()
    bytes_out = out.numel() * out.element_size()
    gbps = (bytes_in + bytes_out) / time_per_iter_sec / 1e9

    total_kernel_size = rn18.maxpool.kernel_size ** 2
    gflops = out.numel() * total_kernel_size / time_per_iter_sec / 1e9

    print(name, gbps, 'GB/s', gflops, 'GFLOP/s')

max_pool_input = torch.randn(10, 64, 112, 112, requires_grad=False)
maxpool_bench(max_pool_input, 'maxpool NCHW')

max_pool_input_nhwc = max_pool_input.contiguous(memory_format=torch.channels_last)
maxpool_bench(max_pool_input_nhwc, 'maxpool NHWC')

NCHW
[chrome trace screenshot]

NHWC
[chrome trace screenshot]

Results from running on my machine (lscpu):

maxpool NCHW 4.765286484398438 GB/s 2.1443789179792976 GFLOP/s
maxpool NHWC 3.958536155344408 GB/s 1.7813412699049835 GFLOP/s

These results are well below both the machine's peak memory/cache bandwidth and its peak GFLOP/s.
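
As a rough sanity check on that gap, one can time a plain tensor copy of the same data volume; torch.clone is only a crude proxy for streaming bandwidth, but it gives a ballpark ceiling to compare against:

import time
import torch

buf = torch.randn(10, 64, 112, 112)
buf.clone()  # warmup

NITER = 100
s = time.time()
for _ in range(NITER):
    out = buf.clone()
e = time.time()

# clone reads the source once and writes the destination once
bytes_moved = 2 * buf.numel() * buf.element_size()
print('copy bandwidth:', bytes_moved * NITER / (e - s) / 1e9, 'GB/s')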

perf indicates that we're hitting max_pool2d_with_indices_single_out_frame.

A few things pop out here:

  • We unconditionally compute indices for the backward pass, even when the network is running in inference mode.
  • The inner loop cannot be vectorized as written. Disassembly confirms that it isn't (ucomiss comparisons on scalar single-precision floats):
    [disassembly screenshot]
  • There's no difference for the channels-last memory layout. Making this fast for NCHW is understandably hard, but NHWC should be relatively trivial, since the channel dimension is innermost and contiguous. The quantized op kernel already does this; a sketch of the idea follows this list.
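
For the last two bullets, a minimal Python sketch of the idea (not the proposed kernel fix, and assuming ResNet18's maxpool hyperparameters k=3, stride=2, padding=1): skip index computation entirely and reduce with dense elementwise maximums, each of which vectorizes over the contiguous channel dimension in channels-last layout:

import torch
import torch.nn.functional as F

def max_pool2d_nhwc_sketch(x, k=3, stride=2, padding=1):
    # Pad with -inf so padded positions never win the max.
    x = F.pad(x, (padding, padding, padding, padding), value=float('-inf'))
    H, W = x.shape[-2:]
    out_h = (H - k) // stride + 1
    out_w = (W - k) // stride + 1
    # One strided view per (i, j) offset inside the k x k window; each
    # torch.maximum is a dense elementwise op over N x C x out_h x out_w
    # and needs no index bookkeeping.
    out = None
    for i in range(k):
        for j in range(k):
            window = x[:, :, i:i + stride * out_h:stride, j:j + stride * out_w:stride]
            out = window if out is None else torch.maximum(out, window)
    return out

x = torch.randn(10, 64, 112, 112).contiguous(memory_format=torch.channels_last)
assert torch.equal(max_pool2d_nhwc_sketch(x), F.max_pool2d(x, 3, stride=2, padding=1))

The same loop structure would map directly to a SIMD max over C in a native channels-last kernel.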

cc @VitalyFedyunin @ngimel @heitorschueroff

Labels: module: performance, module: pooling, triaged
