🚀 The feature, motivation and pitch
```python
import torch
from triton.testing import do_bench

torch.set_default_device('cuda')

def down_size(size):
    assert size[-1] % 2 == 0, f"{size} last dim not divisible by two"
    return (*size[:-1], size[-1] // 2)

def up_size(size):
    return (*size[:-1], size[-1] * 2)

def unpack_uint4(uint8_data) -> torch.Tensor:
    """Get the original weight from the normalized float weight format"""
    shape = uint8_data.shape
    # Since we are using uint8, we decode 2 entries per byte:
    # shift elements down 4 bits and select out the bottom 4 bits.
    #
    # Note: known slow with triton
    # * currently generates two kernels with a cat in between
    # * after https://github.com/pytorch/pytorch/pull/123278 lands, I
    #   verified that we get a single triton kernel, but that is even slower
    #   than the two kernels before this PR
    # * TODO: add a microbenchmark of just the cast and profile this
    first_elements = (uint8_data >> 4).to(torch.uint8)
    second_elements = (uint8_data & 0b1111).to(torch.uint8)
    unpacked = torch.stack([first_elements, second_elements], dim=-1).view(up_size(shape))

    # Trying Bert Maher's suggestion.
    # 2024-04-04: this works in unit tests but is broken on LLaMa 7B FFN with
    #   ptxas /tmp/tmp84wp7lea.ptx, line 227; error : Unexpected instruction types specified for 'sub'
    # which seems to be the same issue as https://github.com/pytorch/pytorch/issues/118589
    # TODO(later): try removing subtractions from our cast to see if we can work around it
    # shift_tensor = torch.tensor([4, 0], dtype=torch.uint8, device=uint8_data.device)
    # unpacked = (uint8_data.reshape(-1)[::, None] >> shift_tensor) & 0b1111
    # unpacked = unpacked.view(up_size(shape))
    return unpacked

@torch.compile
def f(x):
    return unpack_uint4(x) * 2

inp = torch.zeros(2, 4096, 11008 // 2, dtype=torch.uint8)
out = f(inp)
tm = do_bench(lambda: f(inp))  # runtime in milliseconds
print(tm)
# Achieved memory bandwidth in TB/s: bytes read plus bytes written per call,
# times calls per second (1e3 / tm), divided by 1e12.
print(((1e3 / tm) * (inp.numel() * inp.dtype.itemsize + out.numel() * out.dtype.itemsize)) / 1e12)
```
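For reference, here is a quick round-trip check of the packing convention `unpack_uint4` assumes (high nibble at even output indices, low nibble at odd). `pack_uint4` is a hypothetical helper written only for illustration, reusing `down_size` and `unpack_uint4` from above:

```python
def pack_uint4(uint8_data: torch.Tensor) -> torch.Tensor:
    # Inverse of unpack_uint4: fold pairs of 4-bit values (each in 0..15)
    # back into bytes, with the even-indexed element in the high nibble.
    shape = uint8_data.shape
    assert shape[-1] % 2 == 0
    flat = uint8_data.contiguous().view(-1)
    return ((flat[::2] << 4) | flat[1::2]).view(down_size(shape))

packed = torch.randint(0, 256, (2, 8), dtype=torch.uint8)
assert torch.equal(pack_uint4(unpack_uint4(packed)), packed)
```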
We end up generating loads like this:

```
tmp5 = tl.load(in_ptr0 + ((5504*x1) + (x0 // 2)), tmp4, eviction_policy='evict_last', other=0.0)
```
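(For anyone reproducing this: one way to dump the generated kernel, assuming a PyTorch build recent enough that `torch._logging.set_logs` accepts the `output_code` flag, is a sketch like:)

```python
import torch
import torch._dynamo

torch._logging.set_logs(output_code=True)  # log Inductor's generated Triton code
torch._dynamo.reset()                      # force recompilation so the kernel is logged
f(inp)                                     # f and inp as defined in the benchmark above
```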
This seems to be due to a heuristic that is applied during lowering (when we are much more flexible about changing iteration order). In fact, if we modify this line (https://github.com/pytorch/pytorch/blob/main/torch/_inductor/lowering.py#L1246) to

```python
out = pointwise_cat(inputs, dim)
out.realize()
```

this boosts our performance on this kernel by 50%.
In the broader kernel that @vkuzo is looking at, it boosts our performance by 3x.
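As a point of comparison while the lowering question gets sorted out, here is a sketch of a cat-free variant of the unpack (an illustration, not a proposed fix): it writes the two nibble streams into an interleaved view of a preallocated output, so it avoids the stack/cat entirely, at the cost of introducing mutation, which may hit different codegen paths under torch.compile. It reuses `up_size` from above:

```python
def unpack_uint4_no_cat(uint8_data: torch.Tensor) -> torch.Tensor:
    # Same output layout as unpack_uint4: high nibble at even indices,
    # low nibble at odd indices.
    out = torch.empty(up_size(uint8_data.shape), dtype=torch.uint8,
                      device=uint8_data.device)
    out[..., ::2] = uint8_data >> 4
    out[..., 1::2] = uint8_data & 0b1111
    return out
```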
cc: @eellison on pointwise cat
cc: @shunting314 on layouts
Alternatives
No response
Additional context
No response
cc @ezyang @msaroufim @bdhirsh @anijain2305 @chauhang @voznesenskym @penguinwu @EikanWang @jgong5 @Guobing-Chen @XiaobingSuper @zhuhaozhe @blzheng @wenzhe-nrv @jiayisunx @peterbell10 @ipiszy @yf225 @chenyang78 @kadeng @muchulee8 @aakhundov @ColinPeppler @amjames @desertfire