dont let partitioner think it can fuse pointwise ops into user triton kernels #136878
Conversation
🔗 Helpful Links
🧪 See artifacts and rendered test results at hud.pytorch.org/pr/136878
Note: Links to docs will display an error until the docs builds have been completed.
✅ You can merge normally! (1 Unrelated Failure) As of commit db9375a with merge base f0fa460.
FLAKY - The following job failed but was likely due to flakiness present on trunk.
This comment was automatically generated by Dr. CI and updates every 15 minutes.
This pull request was exported from Phabricator. Differential Revision: D63551393
torch/_functorch/partitioners.py
Outdated
Can you write a test?
Separately, what did the failure mode look like?
test added, also updated the description.
The failure mode was that the compiled forward from inductor contained 2 kernels (user_triton, dedicated_inductor_kernel_for_sigmoid), whereas the "better" outcome would have been to move the sigmoid() to the backward (so it could be fused into an existing inductor kernel in the backward).
Force-pushed from 0d75d3a to 8311f22
Force-pushed from 8311f22 to 63e502d
Force-pushed from 63e502d to 4651417
Force-pushed from 4651417 to 130bdfe
Force-pushed from 130bdfe to a3bf5a6
        return True
    if can_fuse_into_triton_kernel_wrapper_functional(a, b):
        return True
    if (
Actually, I don't think this is the right thing to do 🤔
I think the root of the problem here is how we treat operator.getitem (we've run into other issues with views in the past). Basically, we're currently treating operator.getitem as a "fusible" op, but it's actually a "free" op/view, and I think that's actually morally different.
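For context, a minimal standalone illustration (not from the PR) of where these operator.getitem nodes come from: any op that returns multiple tensors gets indexed via operator.getitem in the traced graph, and that node performs no compute of its own.
import operator

import torch
from torch.fx import symbolic_trace


def g(x):
    vm = torch.var_mean(x)  # an op returning a tuple of tensors
    return vm[0] + vm[1]    # each index becomes an operator.getitem node


gm = symbolic_trace(g)
# The getitem nodes are "free" indexing into var_mean's output; the node that
# matters for fusion decisions is var_mean itself.
print([n for n in gm.graph.nodes if n.target is operator.getitem])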
Totally agreed - I sent this to Richard but I was going to check with you - what do you think of something like this instead?
a = recursively_remove_getitems(a)
b = recursively_remove_getitems(b)
return op_types.is_fusible(a) and op_types.is_fusible(b)
since as you pointed out, we treat operator.getitem as "always fusible", which seems bad (aka any other ops that return tuples of tensors but are not themselves fusible might suffer in a similar way).
In terms of landing order, I was thinking of landing this change first since it's needed to unblock internal and is a bit less risky.
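For concreteness, a minimal sketch of what that recursively_remove_getitems helper could look like (the helper is hypothetical, floated in the comment above; it is not an existing API in partitioners.py): it looks through operator.getitem fx nodes to the node that actually produced the multi-output value, so is_fusible is evaluated on the real producer.
import operator


def recursively_remove_getitems(node):
    # operator.getitem nodes are "free" indexing into a multi-output value;
    # walk up to the underlying producer (node.args[0]) before asking whether
    # the op is fusible.
    while node.op == "call_function" and node.target is operator.getitem:
        node = node.args[0]
    return node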
I'm fine with landing this first. Out of curiosity, does #126446 also solve this? Perhaps while also including operator.getitem in the view list?
> Out of curiosity, does #126446 also solve this? Perhaps while also including operator.getitem in the view list?
Hmm it doesn't look like it (my local copy already has that change to always recompute views, and I also tried tweaking the list to include operator.getitem in the list of views). My local change:
diff --git a/torch/_functorch/partitioners.py b/torch/_functorch/partitioners.py
index 81e2f297f6f..8c02fb68211 100644
--- a/torch/_functorch/partitioners.py
+++ b/torch/_functorch/partitioners.py
@@ -1293,6 +1293,7 @@ def get_default_op_list() -> OpTypes:
         aten.as_strided,
         aten.permute,
         aten.select,
+        operator.getitem,
     ]
     view_ops = recomputable_view_ops
     default_recomputable_ops += [
I tried running the same test locally that I have in my PR, and I'm still seeing sigmoid() get saved as an activation (even though it could be fused into an existing inductor backward kernel).
(my local test):
import torch
import triton
import triton.language as tl


def test_triton_kernel_not_fusable_with_users():
    @triton.jit
    def _sin_kernel(
        in_ptr0,
        out_ptr,
        out2_ptr,
        n_elements,
        BLOCK_SIZE: "tl.constexpr",
    ):
        pid = tl.program_id(axis=0)
        block_start = pid * BLOCK_SIZE
        offsets = block_start + tl.arange(0, BLOCK_SIZE)
        mask = offsets < n_elements
        x = tl.load(in_ptr0 + offsets, mask=mask)
        output = tl.sin(x)
        tl.store(out_ptr + offsets, output, mask=mask)
        tl.store(out2_ptr + offsets, output, mask=mask)

    from typing import List

    from torch._library import capture_triton, triton_op

    @triton_op("mylib::sin_kernel", mutates_args={})
    def sin_kernel(x: torch.Tensor) -> List[torch.Tensor]:
        n_elements = x.numel()
        out = torch.empty_like(x)
        out2 = torch.empty_like(x)
        capture_triton(_sin_kernel)[(n_elements,)](
            x, out, out2, n_elements, BLOCK_SIZE=4
        )
        return [out, out2]

    class MySin(torch.autograd.Function):
        @staticmethod
        def forward(ctx, x):
            out, saved = tuple(torch.ops.mylib.sin_kernel(x))
            ctx.save_for_backward(x, saved)
            return out

        @staticmethod
        def backward(ctx, grad):
            (x, saved) = ctx.saved_tensors
            return grad * saved.sigmoid() * x

    @torch.compile(backend="aot_eager")
    def f(x):
        return MySin.apply(x)

    x = torch.randn(4, 4, requires_grad=True, device='cuda')
    out = f(x)


test_triton_kernel_not_fusable_with_users()
Force-pushed from a3bf5a6 to f994cef
@pytorchbot rebase
@pytorchbot started a rebase job onto refs/remotes/origin/viable/strict. Check the current status here.
Successfully rebased |
Force-pushed from f994cef to db9375a
@pytorchbot merge
Merge started: Your change will be merged once all checks pass (ETA 0-4 Hours). Learn more about merging in the wiki. Questions? Feedback? Please reach out to the PyTorch DevX Team.
Merge failed. Reason: 1 job has failed, first few of them are: inductor-periodic / cuda12.1-py3.10-gcc9-sm80 / test (inductor_torchbench_smoketest_perf, 1, 1, linux.gcp.a100). Details for Dev Infra team: raised by workflow job.
@pytorchbot merge
Merge started: Your change will be merged once all checks pass (ETA 0-4 Hours). Learn more about merging in the wiki. Questions? Feedback? Please reach out to the PyTorch DevX Team.
Previously, if we had a graph where a user triton kernel's output is passed through a pointwise sigmoid() that is later multiplied with a tangent in the backward, the partitioner would assume that the sigmoid() could be fused into either its user (the pointwise mul) or its producer (the user triton kernel). This could lead to a bad partitioning:
(1) If the partitioner thinks we can fuse the sigmoid with its producer triton kernel, we keep the sigmoid compute in the forward and have to generate two separate kernels in the forward (the user triton kernel and a dedicated sigmoid kernel).
(2) If the partitioner instead puts the sigmoid in the backward, it can be fused into an existing backward kernel (the mul with a tangent); see the sketch below.
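As a compact illustration of that pattern (a sketch only; the op name "mylib::sin_kernel" is borrowed from the test earlier in this thread and is not part of the PR itself):
import torch


class MyOp(torch.autograd.Function):
    @staticmethod
    def forward(ctx, x):
        # "mylib::sin_kernel" stands in for any user-defined triton_op that
        # returns multiple tensors (see the test in the review thread above).
        out, saved = tuple(torch.ops.mylib.sin_kernel(x))
        ctx.save_for_backward(saved)
        return out

    @staticmethod
    def backward(ctx, grad):
        (saved,) = ctx.saved_tensors
        # sigmoid() is pointwise. The partitioner should recompute it here,
        # where it fuses into this mul with the tangent, rather than keep it
        # in the forward, where inductor would emit a dedicated sigmoid kernel
        # next to the user triton kernel it cannot fuse into.
        return grad * saved.sigmoid()
With this PR, the partitioner no longer treats the user triton kernel as a fusion target for the sigmoid, so recomputing the sigmoid in the backward is the plan it chooses.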
Reviewed By: embg
Differential Revision: D63551393
cc @voznesenskym @penguinwu @EikanWang @jgong5 @Guobing-Chen @XiaobingSuper @zhuhaozhe @blzheng @wenzhe-nrv @jiayisunx @ipiszy @yf225 @chenyang78 @kadeng @muchulee8 @ColinPeppler @amjames @desertfire @chauhang