[SDPA] Allow user-specified priority order with context manager by eqy · Pull Request #140467 · pytorch/pytorch · GitHub

Conversation

@eqy (Collaborator) commented Nov 12, 2024

TODO: docs changes?
For better debuggability of issues like #139298

Better testing, current sketch:

```python
import time

import torch
from torch.nn.functional import scaled_dot_product_attention
from torch.nn.attention import SDPBackend, sdpa_kernel

q = torch.randn(64, 1024, 8, 64, dtype=torch.half, device='cuda')
print(torch._C._get_sdp_priority_order())

# Three candidate priority orders over the same three backends.
orders = [[SDPBackend.CUDNN_ATTENTION, SDPBackend.MATH, SDPBackend.EFFICIENT_ATTENTION],
          [SDPBackend.MATH, SDPBackend.CUDNN_ATTENTION, SDPBackend.EFFICIENT_ATTENTION],
          [SDPBackend.EFFICIENT_ATTENTION, SDPBackend.CUDNN_ATTENTION, SDPBackend.MATH]]

times = []
for order in orders:
    print(order)
    # Warmup pass so one-time setup does not skew the timing.
    with sdpa_kernel(order, set_priority=True):
        scaled_dot_product_attention(q, q, q)
    torch.cuda.synchronize()
    t0 = time.perf_counter()
    with sdpa_kernel(order, set_priority=True):
        scaled_dot_product_attention(q, q, q)
    torch.cuda.synchronize()
    t1 = time.perf_counter()
    times.append(t1 - t0)

print(times)
# Runtime should track the highest-priority backend of each order.
assert times[0] < times[1]
assert times[0] > times[2]
assert times[1] > times[2]
print(torch._C._get_sdp_priority_order())
```

cc @csarofeen @ptrblck @xwang233 @drisspg @mikaylagawarecki

@eqy added the module: cudnn (Related to torch.backends.cudnn, and CuDNN support), open source, topic: improvements (topic category), topic: not user facing (topic category), and module: multi-headed-attention labels Nov 12, 2024
@eqy requested a review from drisspg November 12, 2024 23:16
@pytorch-bot bot commented Nov 12, 2024

🔗 Helpful Links

🧪 See artifacts and rendered test results at hud.pytorch.org/pr/140467

Note: Links to docs will display an error until the docs builds have been completed.

✅ No Failures

As of commit da74e3d with merge base d0fd42e:
💚 Looks good so far! There are no failures yet. 💚

This comment was automatically generated by Dr. CI and updates every 15 minutes.

@drisspg (Contributor) commented Nov 13, 2024

I agree w/ the premise here, will do a review tomorrow. I don't know if I like jamming the list into sdpa_kernel... but let me sleep on it.

@soulitzer added the triaged label (This issue has been looked at by a team member, and triaged and prioritized into an appropriate module) Nov 13, 2024
@Skylion007 (Collaborator)

It would make sense for sdpa_kernel to me. The old API is a little odd because the order of the backends has no effect on anything, wasting useful information. It really should have taken a set previously, if the order did not matter.
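That is, the pre-existing behavior treats the list as an unordered allow-set. A sketch of that stated behavior (assuming a CUDA device where both backends are available):

```python
import torch
from torch.nn.attention import SDPBackend, sdpa_kernel
from torch.nn.functional import scaled_dot_product_attention

q = torch.randn(8, 128, 8, 64, dtype=torch.half, device='cuda')

# Pre-existing behavior: the list only enables backends, so these two
# calls are equivalent -- the order carries no information.
with sdpa_kernel([SDPBackend.MATH, SDPBackend.EFFICIENT_ATTENTION]):
    scaled_dot_product_attention(q, q, q)
with sdpa_kernel([SDPBackend.EFFICIENT_ATTENTION, SDPBackend.MATH]):
    scaled_dot_product_attention(q, q, q)
```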

@Skylion007 (Collaborator)

What if the user never wanted to fall back to the MATH backend? For example:
order=[SDPBackend.EFFICIENT_ATTENTION, SDPBackend.CUDNN_ATTENTION]
seems like it should be valid / supported.

@eqy (Collaborator, Author) commented Nov 13, 2024

> What if the user never wanted to fall back to the MATH backend? For example: order=[SDPBackend.EFFICIENT_ATTENTION, SDPBackend.CUDNN_ATTENTION] seems like it should be valid / supported.

If a backend is not in the list, it will not be selected as a fallback; this sets the priority order for all backends but preserves the current behavior of only enabling the ones in the list.
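For example, a sketch of that semantics (assuming a CUDA build where both listed backends are available):

```python
import torch
from torch.nn.attention import SDPBackend, sdpa_kernel
from torch.nn.functional import scaled_dot_product_attention

q = torch.randn(8, 128, 8, 64, dtype=torch.half, device='cuda')

# MATH is omitted from the list, so it stays disabled inside the context:
# only the listed backends are eligible, tried in the given order.
with sdpa_kernel([SDPBackend.EFFICIENT_ATTENTION, SDPBackend.CUDNN_ATTENTION],
                 set_priority=True):
    scaled_dot_product_attention(q, q, q)
```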

@Skylion007 (Collaborator)

> What if the user never wanted to fall back to the MATH backend? For example: order=[SDPBackend.EFFICIENT_ATTENTION, SDPBackend.CUDNN_ATTENTION] seems like it should be valid / supported.
>
> If a backend is not in the list, it will not be selected as a fallback; this sets the priority order for all backends but preserves the current behavior of only enabling the ones in the list.

Yeah, but is there a case where you want priority but also want it to fall back to kernels outside of the list?

@eqy (Collaborator, Author) commented Nov 14, 2024

Not really; this is just a housekeeping thing, as priority is currently a fixed-size array, so there will be "something" in the priority list even if it's nonsensical.
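Concretely, a sketch using the internal getter from the test above (the exact contents are build-dependent):

```python
import torch
from torch.nn.attention import SDPBackend, sdpa_kernel

# The underlying priority array always has an entry for every backend;
# omitting a backend from sdpa_kernel's list disables it but does not
# remove it from the fixed-size priority order.
with sdpa_kernel([SDPBackend.EFFICIENT_ATTENTION], set_priority=True):
    print(torch._C._get_sdp_priority_order())
```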

@drisspg (Contributor) commented Nov 18, 2024

My magic-wand API would be to only accept the list for this context manager and have it specify the order.
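Something like this hypothetical shape (not the API merged here, which keeps a separate set_priority flag):

```python
import torch
from torch.nn.attention import SDPBackend, sdpa_kernel
from torch.nn.functional import scaled_dot_product_attention

q = torch.randn(8, 128, 8, 64, dtype=torch.half, device='cuda')

# Hypothetical: the list alone would both enable these backends and fix
# their priority, with no extra flag to opt in to ordering.
with sdpa_kernel([SDPBackend.CUDNN_ATTENTION, SDPBackend.EFFICIENT_ATTENTION]):
    scaled_dot_product_attention(q, q, q)
```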

Review comment (Contributor):

I doubt it would be perceptible, but I wonder if this adds any slowdown.
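One way to sanity-check (a rough sketch comparing the call with and without the priority-setting path; not a rigorous benchmark):

```python
import timeit

import torch
from torch.nn.attention import SDPBackend, sdpa_kernel
from torch.nn.functional import scaled_dot_product_attention

q = torch.randn(8, 128, 8, 64, dtype=torch.half, device='cuda')
order = [SDPBackend.EFFICIENT_ATTENTION, SDPBackend.CUDNN_ATTENTION, SDPBackend.MATH]

def run(set_priority):
    with sdpa_kernel(order, set_priority=set_priority):
        scaled_dot_product_attention(q, q, q)
    torch.cuda.synchronize()

# Compare wall time with and without priority-setting enabled.
print(timeit.timeit(lambda: run(True), number=100))
print(timeit.timeit(lambda: run(False), number=100))
```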

Review comment (Contributor):

This technically changes the namespace of the symbol, right?

Review comment (Contributor):

This is a good middle ground for the magic-wand API, though.

@eqy (Collaborator, Author) commented Nov 25, 2024

@pytorchmergebot rebase

@pytorchmergebot (Collaborator)

@pytorchbot started a rebase job onto refs/remotes/origin/viable/strict. Check the current status here

@pytorchmergebot (Collaborator)

Successfully rebased prioritysdpa2 onto refs/remotes/origin/viable/strict, please pull locally before adding more changes (for example, via git checkout prioritysdpa2 && git pull --rebase)

@drisspg added the module: sdpa label (All things related to torch.nn.functional.scaled_dot_product_attention) and removed the module: multi-headed-attention label Nov 27, 2024
@eqy (Collaborator, Author) commented Dec 5, 2024

@pytorchmergebot merge

@pytorch-bot bot added the ciflow/trunk label (Trigger trunk jobs on your pull request) Dec 5, 2024
@pytorchmergebot (Collaborator)

Merge started

Your change will be merged once all checks pass (ETA 0-4 Hours).

Learn more about merging in the wiki.

Questions? Feedback? Please reach out to the PyTorch DevX Team

Advanced Debugging: check the merge workflow status here.

@albanD removed their request for review December 5, 2024 23:06
@pytorchmergebot (Collaborator)

Merge failed

Reason: 1 job has failed: trunk / linux-focal-cuda12.4-py3.10-gcc9-sm86 / build

Details for Dev Infra team Raised by workflow job

@eqy (Collaborator, Author) commented Dec 6, 2024

@pytorchmergebot merge

@pytorchmergebot (Collaborator)

Merge started

Your change will be merged once all checks pass (ETA 0-4 Hours).

Learn more about merging in the wiki.

Questions? Feedback? Please reach out to the PyTorch DevX Team

Advanced Debugging: check the merge workflow status here.

AmdSampsa pushed a commit to AmdSampsa/pytorch that referenced this pull request Dec 9, 2024
…rch#140467)

Pull Request resolved: pytorch#140467
Approved by: https://github.com/drisspg
alugorey pushed a commit to alugorey/pytorch that referenced this pull request Mar 24, 2025
…rch#140467)

Pull Request resolved: pytorch#140467
Approved by: https://github.com/drisspg

Labels

ciflow/trunk (Trigger trunk jobs on your pull request), Merged, module: cudnn (Related to torch.backends.cudnn, and CuDNN support), module: sdpa (All things related to torch.nn.functional.scaled_dot_product_attention), open source, topic: improvements (topic category), topic: not user facing (topic category), triaged (This issue has been looked at by a team member, and triaged and prioritized into an appropriate module)

Projects

None yet

Development

Successfully merging this pull request may close these issues.

5 participants