[Inductor] Add Dynamic shape support to user defined triton kernels by oulgen · Pull Request #112523 · pytorch/pytorch · GitHub

Conversation

@oulgen (Contributor) commented Oct 31, 2023

1) This PR moves the grid function codegen to the wrapper so that we can use
   IndentBuffers as opposed to manually adding tabs for indentation.
2) In Inductor, it emits the grid function in the body of the kernel call so
   that the grid function can use free symbols from dynamic shapes.

[ghstack-poisoned]
@pytorch-bot bot commented Oct 31, 2023

🔗 Helpful Links

🧪 See artifacts and rendered test results at hud.pytorch.org/pr/112523

Note: Links to docs will display an error until the docs builds have been completed.

✅ No Failures

As of commit 25684fb with merge base 5a6f801:
💚 Looks good so far! There are no failures yet. 💚

This comment was automatically generated by Dr. CI and updates every 15 minutes.

Comment on lines +517 to +520
grid, code = user_defined_kernel_grid_fn_code(kernel_name, configs, grid)
# Must happen after free symbols are already codegened
with self.prefix.indent():
    self.prefix.splice(code)
Contributor Author

Looking for feedback here. I moved the grid function to the prefix (rather than the header) so that it can access the free symbols from dynamic shapes. This happens to work because prefix contains the free symbols. Is there a better solution or is there a way for me to assert that free symbols are already generated?

@aakhundov (Contributor), Nov 1, 2023

Because the grid function is a function (and not some inlined block of code), wouldn't all the symbols defined in the outer scope, including the ones defined after this function, be visible in the inner scope of this function's body? E.g., this works (in Python):

def fn(a):
  return a * b

b = 123
c = fn(456)

The main thing is that fn is not called before b is defined. But otherwise it's fine that b is defined after fn. So we need to make sure that the calls to Triton kernels are codegened after the required symbols' definitions, but the grid functions can be anywhere in the call function's body before the kernel is called?
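
For contrast, a minimal sketch (not from the PR) of the failing order: the name only has to be bound by the time the function is called, which is why only the kernel calls, not the grid function definitions, need to come after the symbol definitions:

def fn(a):
  return a * b  # b is looked up when fn is called, not when fn is defined

try:
  fn(456)  # fails: b is not bound yet
except NameError as e:
  print(e)  # name 'b' is not defined

b = 123
print(fn(456))  # 56088, works because b is bound before this call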

As for whether the required symbols will be codegened before the call: since the Triton kernel is represented by a Buffer in the IR, which has dependencies, I'd hope that the existing dependency management mechanics take care of this.

Contributor Author

I didn't know that this is how Python semantics worked... C++ certainly does not work this way.

Contributor

Depending on how we decide to deal with codegening the grid in the AOTInductor's C++ wrapper codegen, maybe we could opt for relying on this feature of Python.

@oulgen oulgen added the ciflow/trunk, topic: not user facing, and ciflow/rocm labels Oct 31, 2023
for grid, c in zip(grids, configs):
    guards = [f"meta['{name}'] == {val}" for name, val in c.kwargs.items()]
    guards = " and ".join(guards)
    output.writeline(f"if {guards}: return {grid}")
@aakhundov (Contributor), Nov 1, 2023

Should we also generate an exception at the end saying something like f"no matching Triton config found for the kernel {name=} and {meta=}"? Otherwise, the function will return None and I'm not sure how clear the downstream error would be.

Contributor Author

Triton will error saying that no suitable grid was found for {name}. But yes, we will not know what meta looked like. I could add the exception, but by construction aren't we always guaranteed to match? What's the scenario where we would fail to match?

Contributor

Not sure about the scenario; it's more of a defensive / informative check against the unexpected :)
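
For illustration, a hedged sketch of what the emitted grid wrapper could look like if such a fallback were added after the per-config guards. The kernel name, the stand-in value of s0, and the final raise are assumptions for this sketch, not what the PR emits:

s0 = 256  # stand-in for the dynamic-shape free symbol bound in the generated call()

def grid_wrapper_for_add_kernel_autotuned_0(meta):
    # Guards mirror the generated "if ...: return ..." line per autotune config.
    if meta['BLOCK_SIZE'] == 128: return (((s0 + 127)//128), 1, 1)
    if meta['BLOCK_SIZE'] == 64: return (((s0 + 63)//64), 1, 1)
    # Hypothetical defensive fallback discussed above.
    raise RuntimeError(f"no matching Triton config found for add_kernel_autotuned_0, meta={meta}")

print(grid_wrapper_for_add_kernel_autotuned_0({'BLOCK_SIZE': 128}))  # (2, 1, 1)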

@Chillee (Collaborator) commented Nov 1, 2023

Can you provide an example of what the generated code looks like?

@oulgen (Contributor, Author) commented Nov 1, 2023

@Chillee

def call(args):
    arg0_1, arg1_1, arg2_1, arg3_1 = args
    args.clear()
    s0 = arg0_1
    assert_size_stride(arg1_1, (s0, ), (1, ))
    assert_size_stride(arg2_1, (s0, ), (1, ))
    def grid_wrapper_for_add_kernel_autotuned_0(meta):
        if meta['BLOCK_SIZE'] == 128: return (((s0 + 127)//128), 1, 1)
        if meta['BLOCK_SIZE'] == 64: return (((s0 + 63)//64), 1, 1)
    with torch.cuda._DeviceGuard(0):
        torch.cuda.set_device(0) # no-op to ensure context
        buf0 = empty((s0, ), device='cuda', dtype=torch.float32)
        # Source Nodes: [output], Original ATen: [aten.zeros_like]
        stream0 = get_cuda_stream(0)
        triton_poi_fused_zeros_like_0.run(buf0, s0, grid=grid(s0), stream=stream0)
        # Source Nodes: [triton_kernel_wrapper_mutation], Original ATen: []
        add_kernel_autotuned_0.run(in_ptr0=arg1_1, in_ptr1=arg2_1, out_ptr=buf0, n_elements=s0, grid=grid_wrapper_for_add_kernel_autotuned_0, stream=stream0)
        del arg1_1
        del arg2_1
        return (buf0, s0, )

Full code: P870612754


@triton_kernel_wrapper_mutation.py_impl(DispatchKey.CompositeExplicitAutograd)
def triton_kernel_wrapper_mutation_dense(*, kernel_idx, grid, kwargs):
    from torch._inductor.codegen.wrapper import user_defined_kernel_grid_fn_code
Contributor

I missed some PRs: previously, we were evaluating the grid at Dynamo time. Did that change?

Contributor Author

We are still doing the same; this is just for emitting the Python code for the multi-option grid function.

From above

This PR moves the grid function codegen to the wrapper so that we can use IndentBuffers as opposed to manually adding tabs for indentation.

Contributor

gotcha, thanks for clarifying

@Chillee (Collaborator) commented Nov 1, 2023

@oulgen And what's the user code for that generated inductor code?

@oulgen (Contributor, Author) commented Nov 1, 2023

@Chillee

def call_triton(x: torch.Tensor, y: torch.Tensor):
    output = torch.zeros_like(x, requires_grad=grad)
    n_elements = output.numel()
    grid = lambda meta: (triton.cdiv(n_elements, meta["BLOCK_SIZE"]),)
    add_kernel_autotuned[grid](x, y, output, n_elements)
    return output

t1 = torch.rand(256, device="cuda", requires_grad=grad)
t2 = torch.rand(256, device="cuda", requires_grad=grad)

This is from the test cases.

Is there anything in particular you're looking at?

…n kernels"

1) This PR moves the grid function codegen to the wrapper so that we can use
   IndentBuffers as opposed to manually adding tabs for indentation.
2) In Inductor, it emits the grid function in the body of the kernel call so
   that the grid function can use free symbols from dynamic shapes.

cc voznesenskym penguinwu EikanWang jgong5 Guobing-Chen XiaobingSuper zhuhaozhe blzheng wenzhe-nrv jiayisunx peterbell10 ipiszy yf225 chenyang78 kadeng muchulee8 aakhundov ColinPeppler

[ghstack-poisoned]
@Chillee (Collaborator) left a comment

How is the grid function evaluated at Dynamo time for all of the config options? Is that code in this PR?

@oulgen (Contributor, Author) commented Nov 2, 2023

@Chillee

https://github.com/pytorch/pytorch/blob/main/torch/_dynamo/variables/functions.py#L715-L726
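
Roughly, the Dynamo code linked above evaluates the user's grid callable once per autotune config, which is what gives Inductor one grid per config to guard on in the emitted wrapper. A simplified, hedged sketch of that idea (the names and the FakeConfig stand-in are illustrative, not the actual implementation):

from collections import namedtuple

FakeConfig = namedtuple("FakeConfig", ["kwargs"])  # stand-in for triton.Config

def resolve_grids(grid_fn, configs):
    # Invoke the user's grid callable with each config's meta kwargs.
    return [grid_fn(dict(cfg.kwargs)) for cfg in configs]

grid = lambda meta: ((1024 + meta["BLOCK_SIZE"] - 1) // meta["BLOCK_SIZE"], 1, 1)
configs = [FakeConfig({"BLOCK_SIZE": 128}), FakeConfig({"BLOCK_SIZE": 64})]
print(resolve_grids(grid, configs))  # [(8, 1, 1), (16, 1, 1)]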

t = torch.rand(4, 4, device="cuda")
t_view = t.view(16)

compiled_func = torch.compile(
Collaborator

We should add some more concrete tests here for recompilation with dynamic shapes.

Contributor Author

Will do as a follow-up.
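
As a reference point, a minimal sketch of what such a follow-up test could look like (not code from this PR; it assumes a CUDA device, and the add_kernel_autotuned definition here only mirrors the kernel used in the existing tests):

import torch
import triton
import triton.language as tl
from torch._dynamo.testing import CompileCounterWithBackend

@triton.autotune(
    configs=[
        triton.Config({"BLOCK_SIZE": 128}, num_warps=4),
        triton.Config({"BLOCK_SIZE": 64}, num_warps=4),
    ],
    key=["n_elements"],
)
@triton.jit
def add_kernel_autotuned(in_ptr0, in_ptr1, out_ptr, n_elements, BLOCK_SIZE: tl.constexpr):
    pid = tl.program_id(axis=0)
    offsets = pid * BLOCK_SIZE + tl.arange(0, BLOCK_SIZE)
    mask = offsets < n_elements
    x = tl.load(in_ptr0 + offsets, mask=mask)
    y = tl.load(in_ptr1 + offsets, mask=mask)
    tl.store(out_ptr + offsets, x + y, mask=mask)

def call_triton(x, y):
    output = torch.zeros_like(x)
    n_elements = output.numel()
    grid = lambda meta: (triton.cdiv(n_elements, meta["BLOCK_SIZE"]),)
    add_kernel_autotuned[grid](x, y, output, n_elements)
    return output

cnt = CompileCounterWithBackend("inductor")
compiled = torch.compile(call_triton, backend=cnt, dynamic=True)

for n in (256, 512, 1024):
    t1 = torch.rand(n, device="cuda")
    t2 = torch.rand(n, device="cuda")
    torch.testing.assert_close(compiled(t1, t2), t1 + t2)

# With dynamic shapes, one graph should cover all sizes (no recompilation per size).
assert cnt.frame_count == 1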

@oulgen (Contributor, Author) commented Nov 2, 2023

@pytorchbot merge

@pytorchmergebot (Collaborator)

Merge started

Your change will be merged once all checks pass (ETA 0-4 Hours).

Learn more about merging in the wiki.

Questions? Feedback? Please reach out to the PyTorch DevX Team

Advanced Debugging
Check the merge workflow status here.

xuhancn pushed a commit to xuhancn/pytorch that referenced this pull request Nov 7, 2023
…ytorch#112523)

Pull Request resolved: pytorch#112523
Approved by: https://github.com/Chillee
Skylion007 pushed a commit to Skylion007/pytorch that referenced this pull request Nov 14, 2023
…ytorch#112523)

Pull Request resolved: pytorch#112523
Approved by: https://github.com/Chillee

Labels

ciflow/inductor, ciflow/rocm, ciflow/trunk, Merged, module: dynamo, module: inductor, topic: not user facing
