Description
🚀 The feature, motivation and pitch
Hi, I am trying SDPA + NestedTensor on 2.1.0.dev20230724+cu118 on an A100, for inference with batch size > 1. I read that SDPA supports nested tensors, and I would expect it to be faster because no padding is involved. As I understand it, a NestedTensor is just an abstraction over one big tensor plus the start and end index of each sequence, which roughly matches the API of https://github.com/Dao-AILab/flash-attention (which uses indexing rather than padding to handle batches).
However, in my tests NestedTensor is not faster; it is actually slower than the non-nested (padded) version.
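The expected gain can be estimated before timing anything. The back-of-the-envelope sketch below (the helper name `padding_stats` is hypothetical) counts the key/value tokens stored by a dense batch padded to the longest sequence versus a packed/nested one, using the same length distribution as the script below. With one query token per sequence, attention cost grows roughly linearly with the number of key/value tokens, so the waste fraction approximates the share of work a nested input could skip. It ignores kernel launch and nested-tensor bookkeeping overheads.

```python
import random
from typing import Dict, Sequence


def padding_stats(seq_lens: Sequence[int]) -> Dict[str, float]:
    """Compare KV tokens in a padded batch with the tokens a packed layout needs."""
    max_len = max(seq_lens)               # every sequence is padded to the longest one
    useful = sum(seq_lens)                # tokens actually present (packed / nested)
    padded = max_len * len(seq_lens)      # tokens in a dense (B, H, max_len, D) tensor
    return {
        "useful_tokens": useful,
        "padded_tokens": padded,
        "waste_fraction": 1 - useful / padded,
    }


if __name__ == "__main__":
    # Same length distribution as the script (512 max, ~40% padding), seeded for reproducibility
    random.seed(0)
    lens = [int(512 * (1 - random.gauss(0.4, 0.01))) for _ in range(64)]
    lens[6] = 512
    stats = padding_stats(lens)
    print(f"waste fraction: {stats['waste_fraction']:.2f}")
```

For this distribution the waste fraction comes out at roughly 0.4, close to `pad_percentage`, so ideally the nested run would do noticeably less attention work than the padded one.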
Script:

```python
import torch
import random

batch_size = 64
head_dim = 128
num_heads = 32
max_sequence_len = 512
pad_percentage = 0.4

# KV lengths of roughly (1 - 0.4) * 512 ≈ 307 tokens each
seq_len_list = [
    int(max_sequence_len * (1 - random.gauss(pad_percentage, 0.01)))
    for _ in range(batch_size)
]
# One sequence at full length, so the longest nested sequence matches the padded length below
seq_len_list[6] = max_sequence_len
print("seq_len_list", seq_len_list)

# Nested inputs: one query token per sequence (decoding), ragged keys/values
query = torch.rand(batch_size, num_heads, 1, head_dim, dtype=torch.float16).to("cuda")
query = torch.nested.nested_tensor(list(query))
key = torch.nested.nested_tensor(
    [torch.rand(num_heads, seq_len, head_dim, dtype=torch.float16).to("cuda") for seq_len in seq_len_list]
)
value = torch.nested.nested_tensor(
    [torch.rand(num_heads, seq_len, head_dim, dtype=torch.float16).to("cuda") for seq_len in seq_len_list]
)

n_runs = 50
with torch.no_grad():
    # Restrict SDPA to the FlashAttention backend
    with torch.backends.cuda.sdp_kernel(enable_math=False, enable_flash=True, enable_mem_efficient=False):
        start_event = torch.cuda.Event(enable_timing=True)
        end_event = torch.cuda.Event(enable_timing=True)
        torch.cuda.empty_cache()
        torch.cuda.synchronize()
        start_event.record()
        for _ in range(n_runs):
            res = torch.nn.functional.scaled_dot_product_attention(query, key, value)
        end_event.record()
        torch.cuda.synchronize()
        tps = start_event.elapsed_time(end_event) / n_runs  # average ms per call
        print("SDPA nested:", tps, "ms")

# Dense baseline: every sequence padded to max_sequence_len, no attention mask
query = torch.rand(batch_size, num_heads, 1, head_dim, dtype=torch.float16).to("cuda")
key = torch.rand(batch_size, num_heads, max_sequence_len, head_dim, dtype=torch.float16).to("cuda")
value = torch.rand(batch_size, num_heads, max_sequence_len, head_dim, dtype=torch.float16).to("cuda")
with torch.no_grad():
    with torch.backends.cuda.sdp_kernel(enable_math=False, enable_flash=True, enable_mem_efficient=False):
        start_event = torch.cuda.Event(enable_timing=True)
        end_event = torch.cuda.Event(enable_timing=True)
        torch.cuda.empty_cache()
        torch.cuda.synchronize()
        start_event.record()
        for _ in range(n_runs):
            res = torch.nn.functional.scaled_dot_product_attention(query, key, value)
        end_event.record()
        torch.cuda.synchronize()
        tps = start_event.elapsed_time(end_event) / n_runs  # average ms per call
        print("SDPA non-nested:", tps, "ms")
```

Prints:

```
SDPA nested: 2.4578866577148437 ms
SDPA non-nested: 1.7974630737304687 ms
```
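One caveat when reading these numbers: neither timed loop runs an untimed warmup call first, so one-time costs of the first call (with `CUDA_MODULE_LOADING` set to `LAZY`, this may include loading the kernel) are folded into the average. A minimal sketch of a timing helper with warmup follows; `benchmark_ms` is a hypothetical name, and it measures host wall-clock time around an explicit synchronize rather than using CUDA events. PyTorch also ships `torch.utils.benchmark.Timer`, which performs warmups and CUDA synchronization itself.

```python
import time
from typing import Callable, Optional


def benchmark_ms(
    fn: Callable[[], object],
    n_runs: int = 50,
    n_warmup: int = 5,
    sync: Optional[Callable[[], None]] = None,
) -> float:
    """Average wall-clock milliseconds per call of fn, after n_warmup untimed calls."""
    for _ in range(n_warmup):
        fn()                      # absorbs one-time costs (lazy loading, allocator growth)
    if sync is not None:
        sync()                    # drain queued async work before starting the clock
    start = time.perf_counter()
    for _ in range(n_runs):
        fn()
    if sync is not None:
        sync()                    # wait for the timed work to finish before stopping the clock
    return (time.perf_counter() - start) * 1000.0 / n_runs
```

With the tensors from the script it would be called as `benchmark_ms(lambda: torch.nn.functional.scaled_dot_product_attention(query, key, value), sync=torch.cuda.synchronize)` inside the same `no_grad` / `sdp_kernel` context. Whether warmup changes the comparison above has not been checked here.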
cc @cpuhrsch @jbschlosser @bhosmer @drisspg
Alternatives
Calling https://github.com/Dao-AILab/flash-attention directly instead of SDPA. That would be a shame, though, because NestedTensor seems like the right abstraction for this.
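For reference, the indexing scheme described above (one packed buffer plus per-sequence start/end offsets) can be sketched in plain Python. This is a hypothetical illustration of the idea, not PyTorch's NestedTensor internals: sequences are concatenated along the token axis, and an exclusive prefix sum of their lengths gives each sequence's [start, end) range. flash-attention's variable-length functions take offsets of this form (`cu_seqlens_q` / `cu_seqlens_k`) next to packed tensors shaped (total_tokens, num_heads, head_dim); here one list element stands in for one token's (num_heads, head_dim) slice, and the helper names are made up.

```python
from itertools import accumulate
from typing import List, Sequence, Tuple, TypeVar

T = TypeVar("T")


def cu_seqlens(seq_lens: Sequence[int]) -> List[int]:
    """Exclusive prefix sum: sequence i occupies packed rows [out[i], out[i + 1])."""
    return [0] + list(accumulate(seq_lens))


def pack(sequences: Sequence[Sequence[T]]) -> Tuple[List[T], List[int]]:
    """Concatenate ragged sequences along the token axis and return their offsets."""
    flat = [tok for seq in sequences for tok in seq]
    return flat, cu_seqlens([len(s) for s in sequences])


def unpack(flat: Sequence[T], offsets: Sequence[int]) -> List[List[T]]:
    """Recover each sequence by slicing between consecutive offsets."""
    return [list(flat[offsets[i]:offsets[i + 1]]) for i in range(len(offsets) - 1)]
```

For example, `pack([["a", "b", "c"], ["d"], ["e", "f"]])` yields the flat list `["a", "b", "c", "d", "e", "f"]` with offsets `[0, 3, 4, 6]`, and no padding tokens are stored.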
Additional context
PyTorch version: 2.1.0.dev20230724+cu118
Is debug build: False
CUDA used to build PyTorch: 11.8
ROCM used to build PyTorch: N/A
OS: Ubuntu 20.04.5 LTS (x86_64)
GCC version: (Ubuntu 9.4.0-1ubuntu1~20.04.1) 9.4.0
Clang version: Could not collect
CMake version: version 3.26.4
Libc version: glibc-2.31
Python version: 3.9.16 (main, May 15 2023, 23:46:34) [GCC 11.2.0] (64-bit runtime)
Python platform: Linux-5.15.0-1023-aws-x86_64-with-glibc2.31
Is CUDA available: True
CUDA runtime version: Could not collect
CUDA_MODULE_LOADING set to: LAZY
GPU models and configuration:
GPU 0: NVIDIA A100-SXM4-80GB
GPU 1: NVIDIA A100-SXM4-80GB
GPU 2: NVIDIA A100-SXM4-80GB
GPU 3: NVIDIA A100-SXM4-80GB
GPU 4: NVIDIA A100-SXM4-80GB
GPU 5: NVIDIA A100-SXM4-80GB
GPU 6: NVIDIA A100-SXM4-80GB
GPU 7: NVIDIA A100-SXM4-80GB
Nvidia driver version: 510.73.08
cuDNN version: Could not collect
HIP runtime version: N/A
MIOpen runtime version: N/A
Is XNNPACK available: True
CPU:
Architecture: x86_64
CPU op-mode(s): 32-bit, 64-bit
Byte Order: Little Endian
Address sizes: 46 bits physical, 48 bits virtual
CPU(s): 96
On-line CPU(s) list: 0-95
Thread(s) per core: 2
Core(s) per socket: 24
Socket(s): 2
NUMA node(s): 2
Vendor ID: GenuineIntel
CPU family: 6
Model: 85
Model name: Intel(R) Xeon(R) Platinum 8275CL CPU @ 3.00GHz
Stepping: 7
CPU MHz: 2999.998
BogoMIPS: 5999.99
Hypervisor vendor: KVM
Virtualization type: full
L1d cache: 1.5 MiB
L1i cache: 1.5 MiB
L2 cache: 48 MiB
L3 cache: 71.5 MiB
NUMA node0 CPU(s): 0-23,48-71
NUMA node1 CPU(s): 24-47,72-95
Vulnerability Itlb multihit: KVM: Mitigation: VMX unsupported
Vulnerability L1tf: Mitigation; PTE Inversion
Vulnerability Mds: Vulnerable: Clear CPU buffers attempted, no microcode; SMT Host state unknown
Vulnerability Meltdown: Mitigation; PTI
Vulnerability Mmio stale data: Vulnerable: Clear CPU buffers attempted, no microcode; SMT Host state unknown
Vulnerability Retbleed: Vulnerable
Vulnerability Spec store bypass: Vulnerable
Vulnerability Spectre v1: Mitigation; usercopy/swapgs barriers and __user pointer sanitization
Vulnerability Spectre v2: Mitigation; Retpolines, STIBP disabled, RSB filling, PBRSB-eIBRS Not affected
Vulnerability Srbds: Not affected
Vulnerability Tsx async abort: Not affected
Flags: fpu vme de pse tsc msr pae mce cx8 apic sep mtrr pge mca cmov pat pse36 clflush mmx fxsr sse sse2 ss ht syscall nx pdpe1gb rdtscp lm constant_tsc arch_perfmon rep_good nopl xtopology nonstop_tsc cpuid aperfmperf tsc_known_freq pni pclmulqdq monitor ssse3 fma cx16 pcid sse4_1 sse4_2 x2apic movbe popcnt tsc_deadline_timer aes xsave avx f16c rdrand hypervisor lahf_lm abm 3dnowprefetch invpcid_single pti fsgsbase tsc_adjust bmi1 avx2 smep bmi2 erms invpcid mpx avx512f avx512dq rdseed adx smap clflushopt clwb avx512cd avx512bw avx512vl xsaveopt xsavec xgetbv1 xsaves ida arat pku ospke
Versions of relevant libraries:
[pip3] mypy-protobuf==3.4.0
[pip3] numpy==1.24.3
[pip3] pytorch-triton==2.1.0+9e3e10c5ed
[pip3] torch==2.1.0.dev20230724+cu118
[pip3] torchaudio==2.1.0.dev20230724+cu118
[pip3] torchvision==0.16.0.dev20230724+cu118
[pip3] triton==2.0.0
[conda] numpy 1.24.3 pypi_0 pypi
[conda] pytorch-triton 2.1.0+9e3e10c5ed pypi_0 pypi
[conda] torch 2.1.0.dev20230724+cu118 pypi_0 pypi
[conda] torchaudio 2.1.0.dev20230724+cu118 pypi_0 pypi
[conda] torchvision 0.16.0.dev20230724+cu118 pypi_0 pypi
[conda] triton 2.0.0 pypi_0 pypi