[SDPA] Allow user-specified priority order with context manager #140467
Conversation
🔗 Helpful Links
🧪 See artifacts and rendered test results at hud.pytorch.org/pr/140467
Note: Links to docs will display an error until the docs builds have been completed.
✅ No Failures
As of commit da74e3d with merge base d0fd42e.
This comment was automatically generated by Dr. CI and updates every 15 minutes.
I agree w/ the premise here, will do a review tomorrow. I don't know if I like jamming the list into sdpa_kernel, but let me sleep on it.
It would make sense for sdpa_kernel to me. The old API is a little odd because the order of the backends has no effect on anything, which wastes useful information. It really should have taken a set if the order did not matter.
What if the user never wanted to fall back to the MATH backend, for example?
If a backend is not in the list, it will not be selected as a fallback. This sets the priority order for all backends but preserves the current behavior of only enabling the ones in the list.
Yeah, but is there a case where you want priority but also want it to fall back to kernels outside the list?
Not really. This is just a housekeeping thing: priority is currently a fixed-size array, so there will be "something" in the priority list even if it's nonsensical.
My magic-wand API would be to accept only the list for this context manager and have it specify the order.
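To make the semantics discussed above concrete, here is a minimal sketch (assuming the `set_priority` keyword added by this PR): only the listed backends are enabled, so an omitted backend such as MATH is never used as a fallback, and the list order becomes the selection priority.

```python
import torch
from torch.nn.functional import scaled_dot_product_attention
from torch.nn.attention import SDPBackend, sdpa_kernel

q = torch.randn(2, 8, 1024, 64, dtype=torch.half, device='cuda')

# MATH is not in the list, so it stays disabled and cannot be selected
# as a fallback. With set_priority=True, cuDNN attention is tried first,
# then efficient attention.
with sdpa_kernel([SDPBackend.CUDNN_ATTENTION, SDPBackend.EFFICIENT_ATTENTION],
                 set_priority=True):
    out = scaled_dot_product_attention(q, q, q)
```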
I doubt it would be perceptible, but I wonder if this adds any slowdown.
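One rough way to check that (a sketch, not part of this PR's tests; the tensor shape and backend order here are arbitrary assumptions) is to time scaled_dot_product_attention with and without entering the context manager:

```python
import time
import torch
from torch.nn.functional import scaled_dot_product_attention
from torch.nn.attention import SDPBackend, sdpa_kernel

q = torch.randn(2, 8, 1024, 64, dtype=torch.half, device='cuda')
order = [SDPBackend.FLASH_ATTENTION, SDPBackend.EFFICIENT_ATTENTION, SDPBackend.MATH]

def bench(fn, iters=100):
    fn()  # warm-up
    torch.cuda.synchronize()
    t0 = time.perf_counter()
    for _ in range(iters):
        fn()
    torch.cuda.synchronize()
    return (time.perf_counter() - t0) / iters

def plain():
    scaled_dot_product_attention(q, q, q)

def with_priority():
    # Enters/exits the context manager on every call, so any cost of
    # swapping the priority order shows up in the measured time.
    with sdpa_kernel(order, set_priority=True):
        scaled_dot_product_attention(q, q, q)

print(f"plain: {bench(plain):.6f}s  priority order: {bench(with_priority):.6f}s")
```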
aten/src/ATen/SDPBackend.h
Outdated
This technically changes the namespace of the symbol, right?
torch/nn/attention/__init__.py
Outdated
This is a good middle ground relative to the magic-wand API, though.
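For reference, the middle ground being described is to keep the existing backend-list argument and gate the new semantics behind an opt-in flag (a sketch assuming `set_priority` defaults to False, so existing callers see no change):

```python
import torch
from torch.nn.functional import scaled_dot_product_attention
from torch.nn.attention import SDPBackend, sdpa_kernel

q = torch.randn(2, 8, 1024, 64, dtype=torch.half, device='cuda')
backends = [SDPBackend.FLASH_ATTENTION, SDPBackend.EFFICIENT_ATTENTION]

# Existing behavior: the list only selects which backends are enabled;
# its order carries no meaning.
with sdpa_kernel(backends):
    scaled_dot_product_attention(q, q, q)

# Opt-in behavior from this PR: the same list, but its order is also
# installed as the backend priority order.
with sdpa_kernel(backends, set_priority=True):
    scaled_dot_product_attention(q, q, q)
```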
@pytorchmergebot rebase
@pytorchbot started a rebase job onto refs/remotes/origin/viable/strict. Check the current status here.
Successfully rebased: 4c59817 to 71420c1
@pytorchmergebot merge
Merge started. Your change will be merged once all checks pass (ETA 0-4 hours). Learn more about merging in the wiki. Questions? Feedback? Please reach out to the PyTorch DevX Team.
Merge failed. Reason: 1 job has failed; the first few are: trunk / linux-focal-cuda12.4-py3.10-gcc9-sm86 / build. Details for Dev Infra team: raised by workflow job.
@pytorchmergebot merge
Merge started. Your change will be merged once all checks pass (ETA 0-4 hours). Learn more about merging in the wiki. Questions? Feedback? Please reach out to the PyTorch DevX Team.
…rch#140467)

TODO: docs changes?

For better debuggability of issues like pytorch#139298

Better testing, current sketch:

```python
import time
import torch
from torch.nn.functional import scaled_dot_product_attention
from torch.nn.attention import SDPBackend, sdpa_kernel

q = torch.randn(64, 1024, 8, 64, dtype=torch.half, device='cuda')
print(torch._C._get_sdp_priority_order())
orders = [[SDPBackend.CUDNN_ATTENTION, SDPBackend.MATH, SDPBackend.EFFICIENT_ATTENTION],
          [SDPBackend.MATH, SDPBackend.CUDNN_ATTENTION, SDPBackend.EFFICIENT_ATTENTION],
          [SDPBackend.EFFICIENT_ATTENTION, SDPBackend.CUDNN_ATTENTION, SDPBackend.MATH]]

times = list()
for order in orders:
    print(order)
    with sdpa_kernel(order, set_priority=True):
        scaled_dot_product_attention(q, q, q)
    torch.cuda.synchronize()
    t0 = time.perf_counter()
    with sdpa_kernel(order, set_priority=True):
        scaled_dot_product_attention(q, q, q)
    torch.cuda.synchronize()
    t1 = time.perf_counter()
    times.append(t1 - t0)
print(times)
assert times[0] < times[1]
assert times[0] > times[2]
assert times[1] > times[2]
print(torch._C._get_sdp_priority_order())
```

Pull Request resolved: pytorch#140467
Approved by: https://github.com/drisspg
TODO: docs changes?
For better debuggability of issues like #139298
Better testing, current sketch: see the test script in the commit message above.
cc @csarofeen @ptrblck @xwang233 @drisspg @mikaylagawarecki