Improve concat fusion with matmuls · Issue #102804 · pytorch/pytorch · GitHub

Improve concat fusion with matmuls #102804

@Chillee

Description


🚀 The feature, motivation and pitch

Today, our concat fusion is somewhat patchy, with matmuls in particular. This is because, for autotuning reasons, we force matmuls into a fixed layout, which prevents the concat from being automatically fused away.

Here's an example.

import torch

torch.set_default_device('cuda')

@torch.compile
def f(x, y):
    x = torch.mm(x, x)
    y = torch.cos(y)
    return torch.cat([x, y])

f(torch.randn(32, 32), torch.randn(32, 32))

which generates the following code using 3 kernels (it should only need 2):

def call(args):
    arg0_1, arg1_1 = args
    args.clear()
    with torch.cuda._DeviceGuard(0):
        torch.cuda.set_device(0) # no-op to ensure context
        buf0 = empty_strided((32, 32), (32, 1), device='cuda', dtype=torch.float32)
        extern_kernels.mm(arg0_1, arg0_1, out=buf0)  # kernel 1: matmul into its own buffer
        del arg0_1
        buf3 = empty_strided((64, 32), (32, 1), device='cuda', dtype=torch.float32)
        buf1 = as_strided(buf3, (32, 32), (32, 1))  # alias: first half of the cat output
        stream0 = get_cuda_stream(0)
        triton_poi_fused_cat_0.run(buf0, buf1, 1024, grid=grid(1024), stream=stream0)  # kernel 2: copies buf0 into buf3
        del buf0
        buf2 = as_strided(buf3, (32, 32), (32, 1), 1024)  # alias: second half of the cat output
        triton_poi_fused_cos_1.run(arg1_1, buf2, 1024, grid=grid(1024), stream=stream0)  # kernel 3: cos fused with the cat
        del arg1_1
        return (buf3, )

The idea here is that if we're not autotuning, we can just allow the matmul to keep a flexible layout: https://github.com/pytorch/pytorch/blob/main/torch/_inductor/kernel/mm_common.py#L135
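As an illustration of the schedule a flexible layout would enable, here's a hand-written sketch (plain eager PyTorch, not inductor-generated code) where the matmul writes directly into its slice of the preallocated concat buffer, so the separate copy kernel for the cat disappears:

```python
import torch

# Sketch of the desired 2-kernel schedule: preallocate the cat output
# and have each op write in place into its slice of the buffer.
x = torch.randn(32, 32)
y = torch.randn(32, 32)

buf = torch.empty(64, 32)
torch.mm(x, x, out=buf[:32])   # op 1: mm writes directly into the first half
torch.cos(y, out=buf[32:])     # op 2: cos writes directly into the second half

# Same result as the original f(x, y), with no copy kernel for the cat.
expected = torch.cat([torch.mm(x, x), torch.cos(y)])
assert torch.allclose(buf, expected)
```

This only works because `buf[:32]` and `buf[32:]` are contiguous views with the layout mm expects; with a fixed layout chosen for autotuning, inductor can't retarget the matmul's output to such a view and has to copy instead.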

Alternatives

No response

Additional context

No response

cc @ezyang @msaroufim @wconstab @bdhirsh @anijain2305 @voznesenskym @penguinwu @EikanWang @jgong5 @Guobing-Chen @XiaobingSuper @zhuhaozhe @blzheng @Xia-Weiwen @wenzhe-nrv @jiayisunx @peterbell10 @ipiszy @ngimel @yf225

Labels

feature · module: inductor · oncall: pt2 · topic: performance · triaged
