Don't let partitioner think it can fuse pointwise ops into user triton kernels by bdhirsh · Pull Request #136878 · pytorch/pytorch

Conversation

@bdhirsh
Contributor

@bdhirsh bdhirsh commented Sep 27, 2024

Previously if we had a graph like:

        triton_kernel_wrapper_functional_proxy = triton_kernel_wrapper_functional(...)
        getitem: "f32[3][1]cuda:0" = triton_kernel_wrapper_functional_proxy['out_ptr']
        getitem_1: "f32[3][1]cuda:0" = triton_kernel_wrapper_functional_proxy['out2_ptr']
        sigmoid: "f32[3][1]cuda:0" = torch.ops.aten.sigmoid.default(getitem_1)
        mul: "f32[3][1]cuda:0" = torch.ops.aten.mul.Tensor(tangents_1, sigmoid)

The partitioner would assume that the sigmoid() could be fused either into its user (the pointwise mul) or into its producer (the user triton kernel). This could lead to a bad partitioning:

(1) If the partitioner thinks we can fuse the sigmoid with its producer triton kernel, we keep the sigmoid compute in the forward and have to generate two separate kernels in the forward (the user triton kernel, plus a dedicated sigmoid kernel).

(2) If the partitioner instead puts the sigmoid in the backward, it can be fused into an existing backward kernel (the mul with a tangent), which is the better outcome.
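
For intuition, here is a minimal sketch of the kind of check the partitioner needs; the helper names below are hypothetical and this is not the PR's actual implementation. The key observation is that a pointwise op whose input is a getitem on a triton_kernel_wrapper_functional call cannot actually be fused "up" into that user-written Triton kernel, so the partitioner should not count on that fusion.

import operator

import torch.fx

def _is_user_triton_kernel_output(node: torch.fx.Node) -> bool:
    # Outputs of a user triton kernel show up in the graph as
    # operator.getitem on a triton_kernel_wrapper_functional call.
    if node.op == "call_function" and node.target is operator.getitem:
        producer = node.args[0]
        return (
            isinstance(producer, torch.fx.Node)
            and producer.op == "call_function"
            and "triton_kernel_wrapper_functional" in str(producer.target)
        )
    return False

def can_fuse_with_producer(pointwise_node: torch.fx.Node) -> bool:
    # A pointwise op (like the sigmoid above) cannot be fused into a
    # user triton kernel, so the partitioner should not keep it in the
    # forward on the assumption that such a fusion will happen.
    return not any(
        _is_user_triton_kernel_output(inp)
        for inp in pointwise_node.all_input_nodes
    )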

Reviewed By: embg

Differential Revision: D63551393

cc @voznesenskym @penguinwu @EikanWang @jgong5 @Guobing-Chen @XiaobingSuper @zhuhaozhe @blzheng @wenzhe-nrv @jiayisunx @ipiszy @yf225 @chenyang78 @kadeng @muchulee8 @ColinPeppler @amjames @desertfire @chauhang

@pytorch-bot

pytorch-bot bot commented Sep 27, 2024

🔗 Helpful Links

🧪 See artifacts and rendered test results at hud.pytorch.org/pr/136878

Note: Links to docs will display an error until the docs builds have been completed.

✅ You can merge normally! (1 Unrelated Failure)

As of commit db9375a with merge base f0fa460:

FLAKY - The following job failed but was likely due to flakiness present on trunk:

This comment was automatically generated by Dr. CI and updates every 15 minutes.

@facebook-github-bot
Contributor

This pull request was exported from Phabricator. Differential Revision: D63551393

Comment on lines 866 to 873
Contributor

Can you write a test?

Separately, what did the failure mode look like?

Contributor Author

Test added; I also updated the description.

The failure mode was that the compiled forward from inductor contained two kernels (user_triton, dedicated_inductor_kernel_for_sigmoid), when the "better" outcome would have been to move the sigmoid() to the backward (so it could be fused into an existing inductor kernel in the backward).

@pytorch-bot pytorch-bot bot added the ciflow/trunk Trigger trunk jobs on your pull request label Sep 27, 2024
@bdhirsh bdhirsh added the release notes: composability release notes category label Sep 27, 2024

bdhirsh added a commit to bdhirsh/pytorch that referenced this pull request Sep 28, 2024
… kernels (pytorch#136878)

Summary:
Pull Request resolved: pytorch#136878

todo

Test Plan: CI

Reviewed By: embg

Differential Revision: D63551393


return True
if can_fuse_into_triton_kernel_wrapper_functional(a, b):
return True
if (
Collaborator

Actually, I don't think this is the right thing to do 🤔

I think the root of the problem here is how we treat operator.getitem (we've run into other issues with views in the past). Basically, we're currently treating operator.getitem as a "fusible" op, but it's really a "free" op/view, and I think that's morally different.

Contributor Author

Totally agreed - I sent this to Richard but I was going to check with you - what do you think of something like this instead?

a = recursively_remove_getitems(a)
b = recursively_remove_getitems(b)
return op_types.is_fusible(a) and op_types.is_fusible(b)

since, as you pointed out, we treat operator.getitem as "always fusible", which seems bad (any other op that returns a tuple of tensors but is not itself fusible could suffer in the same way).

In terms of landing order, I was thinking of landing this change first since it's needed to unblock internal and is a bit less risky.
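
For reference, a rough sketch of what a recursively_remove_getitems helper along these lines could look like; the name comes from the suggestion above, and the structure below is an assumption rather than code from the PR. It walks past operator.getitem nodes so that fusibility is judged against the op that actually produced the tensor (e.g. the user triton kernel wrapper, which is not fusible), rather than against the "free" getitem.

import operator

import torch.fx

def recursively_remove_getitems(node: torch.fx.Node) -> torch.fx.Node:
    # Peel off operator.getitem layers until we reach the node that
    # actually computes something; getitem itself is a "free" op/view,
    # so its producer is what should drive the fusibility decision.
    while (
        node.op == "call_function"
        and node.target is operator.getitem
        and isinstance(node.args[0], torch.fx.Node)
    ):
        node = node.args[0]
    return node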

Collaborator

@Chillee Chillee Sep 30, 2024

I'm fine with landing this first. Out of curiosity, does #126446 also solve this? Perhaps while also including operator.getitem in the view list?

Contributor Author

Out of curiosity, does #126446 also solve this? Perhaps while also including operator.getitem in the view list?

Hmm it doesn't look like it (my local copy already has that change to always recompute views, and I also tried tweaking the list to include operator.getitem in the list of views). My local change:

diff --git a/torch/_functorch/partitioners.py b/torch/_functorch/partitioners.py
index 81e2f297f6f..8c02fb68211 100644
--- a/torch/_functorch/partitioners.py
+++ b/torch/_functorch/partitioners.py
@@ -1293,6 +1293,7 @@ def get_default_op_list() -> OpTypes:
         aten.as_strided,
         aten.permute,
         aten.select,
+        operator.getitem,
     ]
     view_ops = recomputable_view_ops
     default_recomputable_ops += [

I tried running the same test locally that I have in my PR, and I'm still seeing sigmoid() get saved as an activation (even though it could be fused into an existing inductor backward kernel).

Contributor Author

(my local test):

import torch
import triton
import triton.language as tl

def test_triton_kernel_not_fusable_with_users():
    @triton.jit
    def _sin_kernel(
        in_ptr0,
        out_ptr,
        out2_ptr,
        n_elements,
        BLOCK_SIZE: "tl.constexpr",
    ):
        pid = tl.program_id(axis=0)
        block_start = pid * BLOCK_SIZE
        offsets = block_start + tl.arange(0, BLOCK_SIZE)
        mask = offsets < n_elements
        x = tl.load(in_ptr0 + offsets, mask=mask)
        output = tl.sin(x)
        tl.store(out_ptr + offsets, output, mask=mask)
        tl.store(out2_ptr + offsets, output, mask=mask)

    from typing import List

    from torch._library import capture_triton, triton_op

    @triton_op("mylib::sin_kernel", mutates_args={})
    def sin_kernel(x: torch.Tensor) -> List[torch.Tensor]:
        n_elements = x.numel()
        out = torch.empty_like(x)
        out2 = torch.empty_like(x)
        capture_triton(_sin_kernel)[(n_elements,)](
            x, out, out2, n_elements, BLOCK_SIZE=4
        )
        return [out, out2]

    class MySin(torch.autograd.Function):
        @staticmethod
        def forward(ctx, x):
            out, saved = tuple(torch.ops.mylib.sin_kernel(x))
            ctx.save_for_backward(x, saved)
            return out

        @staticmethod
        def backward(ctx, grad):
            (x, saved) = ctx.saved_tensors
            return grad * saved.sigmoid() * x

    @torch.compile(backend="aot_eager")
    def f(x):
        return MySin.apply(x)

    x = torch.randn(4, 4, requires_grad=True, device='cuda')
    out = f(x)

test_triton_kernel_not_fusable_with_users()
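
As an aside (not something stated in the thread): one way to see where the partitioner put the sigmoid when running a test like this is to enable the AOTAutograd graph logs before calling f; these print the partitioned forward and backward graphs, so you can check whether the sigmoid output gets saved as an activation.

import torch._logging

# Print the forward/backward graphs produced after partitioning
# (equivalently, run the script with TORCH_LOGS="aot_graphs").
torch._logging.set_logs(aot_graphs=True)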


@bdhirsh
Contributor Author

bdhirsh commented Oct 2, 2024

@pytorchbot rebase

@pytorchmergebot
Collaborator

@pytorchbot started a rebase job onto refs/remotes/origin/viable/strict. Check the current status here

@pytorchmergebot
Collaborator

Successfully rebased export-D63551393 onto refs/remotes/origin/viable/strict, please pull locally before adding more changes (for example, via git checkout export-D63551393 && git pull --rebase)

@bdhirsh
Contributor Author

bdhirsh commented Oct 2, 2024

@pytorchbot merge

@pytorchmergebot
Collaborator

Merge started

Your change will be merged once all checks pass (ETA 0-4 Hours).

Learn more about merging in the wiki.

Questions? Feedback? Please reach out to the PyTorch DevX Team

Advanced Debugging: Check the merge workflow status here.

@pytorchmergebot
Collaborator

Merge failed

Reason: 1 jobs have failed, first few of them are: inductor-periodic / cuda12.1-py3.10-gcc9-sm80 / test (inductor_torchbench_smoketest_perf, 1, 1, linux.gcp.a100)


@bdhirsh
Contributor Author

bdhirsh commented Oct 2, 2024

@pytorchbot merge

@pytorchmergebot
Collaborator

Merge started

Your change will be merged once all checks pass (ETA 0-4 Hours).

