Closed
Labels: feature, module: inductor, oncall: pt2, topic: performance, triaged
Description
🚀 The feature, motivation and pitch
Today, our concat fusion is somewhat patchy, with matmuls in particular. This is because, for autotuning reasons, we force the matmul output into a `FixedLayout`, which prevents concat from automatically fusing it away (i.e., having the matmul write directly into its slice of the concatenated output).
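To make the constraint concrete, here is a toy model in plain Python. It is not Inductor's implementation: the `Layout` enum, `Buffer` class, and `lower_cat` function are hypothetical stand-ins. The model assumes that a concat input whose layout is still flexible can be computed directly into its slice of the output, while an input with a fixed layout has already committed to its own allocation and needs an extra copy kernel.

```python
from dataclasses import dataclass
from enum import Enum
from typing import List


class Layout(Enum):
    FLEXIBLE = "flexible"  # storage/strides undecided: producer can write anywhere
    FIXED = "fixed"        # storage/strides already committed to its own buffer


@dataclass
class Buffer:
    name: str
    layout: Layout


def lower_cat(inputs: List[Buffer]) -> List[str]:
    """Return the kernels this toy concat lowering would launch.

    Each input is produced by one kernel. A flexible-layout producer is
    redirected to write straight into its slice of the concat output; a
    fixed-layout producer writes to its own buffer, so a copy kernel is added.
    """
    kernels = []
    for buf in inputs:
        if buf.layout is Layout.FLEXIBLE:
            kernels.append(f"{buf.name} -> cat slice")
        else:
            kernels.append(f"{buf.name} -> own buffer")
            kernels.append(f"copy {buf.name} -> cat slice")
    return kernels


# mm forced into a fixed layout (today's behaviour described above).
today = lower_cat([Buffer("mm", Layout.FIXED), Buffer("cos", Layout.FLEXIBLE)])
# mm allowed to keep a flexible layout (the proposal).
proposed = lower_cat([Buffer("mm", Layout.FLEXIBLE), Buffer("cos", Layout.FLEXIBLE)])
print(len(today), len(proposed))  # 3 2
```

The only difference between the two calls is the matmul's layout kind, and that alone accounts for the extra copy kernel.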
Here's an example.
```python
import torch
torch.set_default_device('cuda')

@torch.compile
def f(x, y):
    x = torch.mm(x, x)
    y = torch.cos(y)
    return torch.cat([x, y])

f(torch.randn(32, 32), torch.randn(32, 32))
```
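This generates the following code, which launches 3 kernels (only 2 should be needed): the extern `mm` writes into a standalone buffer `buf0`, `triton_poi_fused_cat_0` copies it into the first half of `buf3`, and `triton_poi_fused_cos_1` writes the cosine into the second half.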
```python
def call(args):
    arg0_1, arg1_1 = args
    args.clear()
    with torch.cuda._DeviceGuard(0):
        torch.cuda.set_device(0) # no-op to ensure context
        # Kernel 1: extern mm into its own fixed-layout buffer
        buf0 = empty_strided((32, 32), (32, 1), device='cuda', dtype=torch.float32)
        extern_kernels.mm(arg0_1, arg0_1, out=buf0)
        del arg0_1
        buf3 = empty_strided((64, 32), (32, 1), device='cuda', dtype=torch.float32)
        buf1 = as_strided(buf3, (32, 32), (32, 1)) # alias
        stream0 = get_cuda_stream(0)
        # Kernel 2: extra copy of the mm result into the first half of buf3
        triton_poi_fused_cat_0.run(buf0, buf1, 1024, grid=grid(1024), stream=stream0)
        del buf0
        buf2 = as_strided(buf3, (32, 32), (32, 1), 1024) # alias
        # Kernel 3: cos written directly into the second half of buf3
        triton_poi_fused_cos_1.run(arg1_1, buf2, 1024, grid=grid(1024), stream=stream0)
        del arg1_1
        return (buf3, )
```
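The two `as_strided` aliases in that output are what let producers write into `buf3` without a copy: `buf1` views rows 0-31 (storage offset 0) and `buf2` views rows 32-63 (storage offset 32 * 32 = 1024). The sketch below mimics that in plain Python, with a flat list as storage and a hypothetical `StridedView` class, and then computes the matmul straight into the first view and the cosine into the second. That is the two-kernel plan this issue asks for; it illustrates the aliasing arithmetic only and is not generated Inductor code. Sizes are shrunk to 2x2 for readability.

```python
import math
from typing import List, Tuple


class StridedView:
    """Hypothetical 2-D view over flat storage, like as_strided(buf, size, stride, offset)."""

    def __init__(self, storage: List[float], size: Tuple[int, int],
                 stride: Tuple[int, int], offset: int = 0):
        self.storage, self.size, self.stride, self.offset = storage, size, stride, offset

    def index(self, i: int, j: int) -> int:
        # Flat position = offset + i * stride[0] + j * stride[1]
        return self.offset + i * self.stride[0] + j * self.stride[1]

    def __setitem__(self, ij: Tuple[int, int], value: float) -> None:
        self.storage[self.index(*ij)] = value


def fused_cat_mm_cos(x: List[List[float]], y: List[List[float]]) -> List[float]:
    n = len(x)
    buf3 = [0.0] * (2 * n * n)                         # one (2n, n) row-major output
    top = StridedView(buf3, (n, n), (n, 1), 0)         # like buf1
    bottom = StridedView(buf3, (n, n), (n, 1), n * n)  # like buf2, offset n * n
    # "Kernel 1": the matmul writes directly into the top half (no temp buffer).
    for i in range(n):
        for j in range(n):
            top[i, j] = sum(x[i][k] * x[k][j] for k in range(n))
    # "Kernel 2": the pointwise cos writes directly into the bottom half.
    for i in range(n):
        for j in range(n):
            bottom[i, j] = math.cos(y[i][j])
    return buf3


x = [[1.0, 2.0], [3.0, 4.0]]
y = [[0.0, 0.0], [0.0, 0.0]]
print(fused_cat_mm_cos(x, y))  # [7.0, 10.0, 15.0, 22.0, 1.0, 1.0, 1.0, 1.0]
```

With n = 32, the bottom view's first element lands at flat index 1024, matching the offset passed to `as_strided` for `buf2`.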
The idea is that if we're not autotuning, we can simply let the matmul keep a flexible layout, so that concat can have it write directly into its slice of the output and the copy kernel goes away: https://github.com/pytorch/pytorch/blob/main/torch/_inductor/kernel/mm_common.py#L135
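The proposed rule can be summarised as a small decision function. This is a hedged sketch: `choose_mm_layout`, its `autotuning` flag, and the `LayoutKind` enum are hypothetical, and the real change would live in the matmul lowering helpers in the `mm_common.py` file linked above, whose code is not reproduced here. The reasoning it encodes is the issue's: the layout is pinned for autotuning reasons (presumably because the benchmarked candidate kernels are specialised to one concrete output layout), so when no autotuning happens the layout can stay flexible.

```python
from enum import Enum


class LayoutKind(Enum):
    FIXED = "fixed"        # strides/storage committed before codegen
    FLEXIBLE = "flexible"  # a consumer such as cat may still choose the storage


def choose_mm_layout(autotuning: bool) -> LayoutKind:
    """Hypothetical policy from the issue: only pin the layout when autotuning."""
    if autotuning:
        # Candidate kernels were chosen for one concrete output layout,
        # so it must not change after the choice is made.
        return LayoutKind.FIXED
    # Nothing was benchmarked, so cat is free to place the output in its slice.
    return LayoutKind.FLEXIBLE


for autotuning in (True, False):
    print(autotuning, choose_mm_layout(autotuning).value)
```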
Alternatives
No response
Additional context
No response
cc @ezyang @msaroufim @wconstab @bdhirsh @anijain2305 @voznesenskym @penguinwu @EikanWang @jgong5 @Guobing-Chen @XiaobingSuper @zhuhaozhe @blzheng @Xia-Weiwen @wenzhe-nrv @jiayisunx @peterbell10 @ipiszy @ngimel @yf225