
Lowering after pointwise cat can lead to uncontiguous memory accesses #124002

@Chillee

Description

🚀 The feature, motivation and pitch

import torch
from triton.testing import do_bench
torch.set_default_device('cuda')

def down_size(size):
    assert size[-1] % 2 == 0, f"{size} last dim not divisible by two"
    return (*size[:-1], size[-1] // 2)

def up_size(size):
    return (*size[:-1], size[-1] * 2)

def unpack_uint4(uint8_data: torch.Tensor) -> torch.Tensor:
    """Unpack each uint8 byte into two uint4 values (high nibble first, then low)"""
    shape = uint8_data.shape

    # since we are using uint8 we will decode 2 entries per byte
    # Shift elements down 4 and select out the bottom 4 bits
    #
    # Note: known slow with triton
    # * currently generates two kernels with a cat in between
    # * after https://github.com/pytorch/pytorch/pull/123278 lands I
    #   verified that we get a single triton kernel, but that is even slower
    #   than the two kernels before this PR
    # * TODO add a microbenchmark of just the cast and profile this
    first_elements = (uint8_data >> 4).to(torch.uint8)
    second_elements = (uint8_data & 0b1111).to(torch.uint8)
    unpacked = torch.stack([first_elements, second_elements], dim=-1).view(up_size(shape))

    # trying Bert Maher's suggestion
    # 2024-04-04: this works in unit tests but is broken on LLaMa 7B FFN with
    #   ptxas /tmp/tmp84wp7lea.ptx, line 227; error   : Unexpected instruction types specified for 'sub'
    # which seems to be the same issue as https://github.com/pytorch/pytorch/issues/118589
    # TODO(later): try removing subtractions from our cast to see if we can work around
    # shift_tensor = torch.tensor([4, 0], dtype=torch.uint8, device=uint8_data.device)
    # unpacked = (uint8_data.reshape(-1)[::, None] >> shift_tensor) & 0b1111
    # unpacked = unpacked.view(up_size(shape))

    return unpacked

@torch.compile
def f(x):
    return unpack_uint4(x) * 2

inp = torch.zeros(2, 4096, 11008//2, dtype=torch.uint8)
out = f(inp)

tm = do_bench(lambda: f(inp))
print(tm)
# effective memory bandwidth in TB/s: (bytes read + bytes written) per second
print(((1e3/tm) * (inp.numel() * inp.dtype.itemsize + out.numel() * out.dtype.itemsize)) / 1e12)
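
As a quick sanity check of the layout unpack_uint4 produces, here is a minimal CPU sketch of the same shift/mask/stack steps (illustrative only, not part of the repro):

    import torch

    # 0xAB packs two uint4 values: high nibble 0xA, low nibble 0xB
    x = torch.tensor([[0xAB, 0xCD]], dtype=torch.uint8)
    hi = (x >> 4).to(torch.uint8)      # [[0xA, 0xC]]
    lo = (x & 0b1111).to(torch.uint8)  # [[0xB, 0xD]]
    # stack interleaves hi/lo along a new trailing dim; view flattens it
    print(torch.stack([hi, lo], dim=-1).view(1, 4))
    # tensor([[10, 11, 12, 13]], dtype=torch.uint8)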

We end up generating loads like this:

    tmp5 = tl.load(in_ptr0 + ((5504*x1) + (x0 // 2)), tmp4, eviction_policy='evict_last', other=0.0)
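
The x0 // 2 term reflects the stack-then-view above: output element x0 along the unpacked last dimension reads input byte x0 // 2, and 5504 == 11008 // 2 is the packed input's row stride. A purely illustrative sketch of the resulting offset pattern:

    # illustrative only: the per-element input offset the generated kernel
    # computes, with x1 indexing rows and x0 the unpacked last dimension
    def input_offset(x1: int, x0: int) -> int:
        return 5504 * x1 + (x0 // 2)  # 5504 == 11008 // 2

    # consecutive output elements hit the same input byte in pairs
    print([input_offset(0, x0) for x0 in range(8)])  # [0, 0, 1, 1, 2, 2, 3, 3]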

This appears to come from a heuristic applied during lowering (a stage at which we are still flexible about changing the iteration order). In fact, if we modify this line (https://github.com/pytorch/pytorch/blob/main/torch/_inductor/lowering.py#L1246) to:

            out = pointwise_cat(inputs, dim)
            out.realize()

(forcing the concatenated result to be materialized into its own buffer), performance on this kernel improves by 50%.
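
For anyone reproducing the before/after comparison, one way to dump the generated Triton code for inspection is output-code logging, either via the TORCH_LOGS="output_code" environment variable or programmatically:

    import torch._logging
    # print Inductor's generated (Triton) code for each compiled graph
    torch._logging.set_logs(output_code=True)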

In the broader kernel that @vkuzo is looking at, it boosts our performance by 3x.

cc: @eellison on pointwise cat
cc: @shunting314 on layouts

Alternatives

No response

Additional context

No response


Metadata

Labels

module: inductor, oncall: pt2, triaged
