Add host-side Triton TMA support to Inductor by aakhundov · Pull Request #137950 · pytorch/pytorch · GitHub

Conversation

@aakhundov
Contributor

@aakhundov aakhundov commented Oct 15, 2024

Stack from ghstack (oldest at bottom):

This adds Dynamo tracing support for the host-side Triton TMA API (see the create_2d_tma_descriptor calls on the host in the Triton tutorial: https://triton-lang.org/main/getting-started/tutorials/09-persistent-matmul.html#sphx-glr-getting-started-tutorials-09-persistent-matmul-py). A few notes (a usage sketch follows the list):

  • Here we assume the availability of the host-side TMA API added to upstream Triton in triton-lang/triton#4498 ("[nvidia] Support passing TMA descriptors by-value"). As of the time of writing, this is not part of the PT2 OSS Triton pin (although it has been back-ported internally). The OSS Triton pin update should be done in December 2024.
  • Due to the Dynamo support implemented in the previous PR, the tma_descriptor_metadata dict is delivered to the triton_kernel_wrap_ lowering and passed to ir.UserDefinedTritonKernel as an additional argument.
  • Looking into the tma_descriptor_metadata, ir.UserDefinedTritonKernel replaces the corresponding TensorBox arguments of the kernel (swapped upstream in Dynamo) with the new ir.TMADescriptor nodes implementing TMA descriptors in Inductor IR.
  • ir.TMADescriptor.__init__ provides the wiring between the upstream underlying ir.TensorBox and the downstream ir.UserDefinedTritonKernel kernel. In particular, we use ir.NonOwningLayout wrapping ir.ReinterpretView to avoid the upstream tensor's buffer being deleted prematurely (i.e., before the TMA descriptor is used in the Triton kernel).
  • Via ir.TMADescriptor.codegen, Triton's create_{1d,2d}_tma_descriptor function call is codegened in the wrapper (in the host code).
  • A new TMADescriptorArg dataclass is added to handle the Triton kernel metadata pertinent to host-side TMA.
  • AOT Inductor support will be implemented in a follow-up PR.
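
To make the notes above concrete, here is a minimal usage sketch (not taken from this PR or its tests) of the kind of user code that this change lets torch.compile trace. It assumes Triton's experimental host-side TMA API from the Triton PR referenced above (create_1d_tma_descriptor in triton.tools.experimental_descriptor on the host, tl._experimental_descriptor_load on the device); the kernel, function, and block-size choices are illustrative.

```python
import torch
import triton
import triton.language as tl
import triton.tools.experimental_descriptor


@triton.jit
def add_one_kernel(in_desc, out_ptr, BLOCK: tl.constexpr):
    pid = tl.program_id(axis=0)
    # Device-side load through the host-created TMA descriptor.
    x = tl._experimental_descriptor_load(in_desc, [pid * BLOCK], [BLOCK], tl.float32)
    offsets = pid * BLOCK + tl.arange(0, BLOCK)
    tl.store(out_ptr + offsets, x + 1.0)


def f(x):
    out = torch.empty_like(x)
    BLOCK = 256
    # Host-side TMA descriptor: Dynamo records its metadata (dims, block dims,
    # element size), and Inductor re-creates the descriptor in the generated
    # wrapper code instead of baking in the opaque descriptor object.
    desc = triton.tools.experimental_descriptor.create_1d_tma_descriptor(
        x.data_ptr(), x.numel(), BLOCK, x.element_size()
    )
    grid = (triton.cdiv(x.numel(), BLOCK),)
    add_one_kernel[grid](desc, out, BLOCK=BLOCK)
    return out


compiled = torch.compile(f, fullgraph=True)
y = compiled(torch.randn(1024, device="cuda"))
```

Whether this exact pattern compiles end-to-end depends on the Triton pin caveat above.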

cc @voznesenskym @penguinwu @EikanWang @jgong5 @Guobing-Chen @XiaobingSuper @zhuhaozhe @blzheng @wenzhe-nrv @jiayisunx @ipiszy @yf225 @chenyang78 @kadeng @muchulee8 @ColinPeppler @amjames @desertfire @chauhang @rec

[ghstack-poisoned]
@pytorch-bot

pytorch-bot bot commented Oct 15, 2024

🔗 Helpful Links

🧪 See artifacts and rendered test results at hud.pytorch.org/pr/137950

Note: Links to docs will display an error until the docs builds have been completed.

✅ No Failures

As of commit 0d93a4e with merge base 4a8e493:
💚 Looks good so far! There are no failures yet. 💚

This comment was automatically generated by Dr. CI and updates every 15 minutes.

[ghstack-poisoned]
@aakhundov aakhundov marked this pull request as ready for review October 15, 2024 04:38
@aakhundov aakhundov requested review from eellison and oulgen October 15, 2024 04:38
[ghstack-poisoned]
aakhundov added a commit that referenced this pull request Oct 15, 2024
Details TBA

ghstack-source-id: 3ec0a5d
Pull Request resolved: #137950
Contributor

@eellison eellison left a comment


Looks good! Would you mind adding a couple of SymInt uses? Also, it would be cool to have a deduping mechanism for TMA descriptors on the same tensor.

autotune_configs=configs,
)

def generate_tma_descriptor(self, desc):

Comment on lines +5274 to +5278
constant_args = [
*self.dims,
*self.block_dims,
self.element_size,
]
Contributor

Is dims here actually constant?

Contributor Author

Not necessarily; they can be SymInts. I was following what UserDefinedTritonKernel does, where all non-TensorBox args are put into the constant_args here.

I was unsure about the semantics of the constant_args parameter of ExternKernel. Looking into the code, it seems self.constant_args is mostly used in the codegen-related methods, which are not relevant for TMADescriptor (nor for UserDefinedTritonKernel), as the codegen is overridden at the root. However, I also see it being used as a potential source of unbacked SymInts here. So perhaps I should keep this code as is?

Comment on lines +5282 to +5285
# link back to the underlying tensor in terms of ownership
# to avoid getting the underlying tensor deleted *before*
# the TMADescriptor node can be deleted.
NonOwningLayout(ReinterpretView(tensor, tensor.get_layout())),
Contributor

I guess this works and I can't think of anything better, but if this comes up as a common pattern maybe we can loop back.

@aakhundov
Contributor Author

would you mind adding a couple symint uses ?

You mean test cases? The unit tests that run with dynamic=True introduce SymInts into the dims passed to the create_{1d,2d}_tma_descriptor calls. Would that be sufficient?
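
For reference, a hedged sketch (not the PR's actual unit test) of what such a dynamic-shape check could look like, reusing the hypothetical f from the usage sketch in the PR description above:

```python
# With dynamic=True, x.numel() becomes a SymInt, so the dims passed to
# create_1d_tma_descriptor inside f are symbolic rather than Python ints.
compiled_dyn = torch.compile(f, dynamic=True, fullgraph=True)
for n in (512, 1024, 2048):
    x = torch.randn(n, device="cuda")
    torch.testing.assert_close(compiled_dyn(x), x + 1.0)
```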

Also, would be cool to have a deduping mechanism for tma_descriptor on same tensor

TMA descriptors are immutable, so this shouldn't be hard to do: it would just need hashing on the underlying tensor and all the args. Let me try.
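
A hedged sketch of that deduping idea (illustrative names, not the code that landed): cache the ir.TMADescriptor nodes keyed on the underlying tensor's identity plus all descriptor arguments, so repeated requests for the same descriptor reuse a single IR node.

```python
# Illustrative only: TMADescriptor here stands for the ir.TMADescriptor node
# added in this PR, and its constructor arguments are assumed.
_tma_descriptor_cache: dict = {}


def get_or_create_tma_descriptor(tensor, dims, block_dims, element_size=None):
    # id(tensor) is used because TensorBox is not hashable (see the discussion below).
    key = (id(tensor), tuple(dims), tuple(block_dims), element_size)
    if key not in _tma_descriptor_cache:
        _tma_descriptor_cache[key] = TMADescriptor(tensor, dims, block_dims, element_size)
    return _tma_descriptor_cache[key]
```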

[ghstack-poisoned]
aakhundov added a commit that referenced this pull request Oct 17, 2024
Details TBA

ghstack-source-id: a248dd1
Pull Request resolved: #137950
block_dims: List[Union[int, torch.SymInt]],
element_size: Optional[int] = None,
):
key = (id(tensor), dims, block_dims, element_size)
Contributor Author

@eellison I'm using id(tensor) here because TensorBox happens to be non-hashable. Although this looks correct, it's very restrictive: we require the same TensorBox Python object to hit the cache. Is there a more canonical way to do this in Inductor IR (one that would allow, e.g., different TensorBoxes referring to the same underlying storage)? Maybe I should unwrap the storage before doing id(...), or would that ignore offsets in views, which can lead to different data_ptr() values?

Contributor

@eellison eellison Oct 18, 2024


You could try to cover different Tensors with the same strides and the same underlying storage, but even then, it would require us to fix the layout. And I'm not sure how common that is. I think this is good.

@aakhundov
Contributor Author

Landing this as the signals look good and all comments are resolved. Happy to address further requests in a follow-up PR.

@aakhundov
Contributor Author

@pytorchbot merge

@pytorch-bot pytorch-bot bot added the ciflow/trunk Trigger trunk jobs on your pull request label Oct 18, 2024
@pytorchmergebot
Collaborator

Merge started

Your change will be merged once all checks pass (ETA 0-4 Hours).

Learn more about merging in the wiki.

Questions? Feedback? Please reach out to the PyTorch DevX Team

Advanced Debugging: check the merge workflow status here.

smalltalkman pushed a commit to smalltalkman/pytorch that referenced this pull request Oct 18, 2024

Pull Request resolved: pytorch#137950
Approved by: https://github.com/eellison
ghstack dependencies: pytorch#137677
@pytorchmergebot
Collaborator

This PR (#137950) was merged in d116d00 but it is still open, likely due to a GitHub bug, so mergebot is closing it manually. If you think this is a mistake, please feel free to reopen and contact Dev Infra.

@github-actions github-actions bot deleted the gh/aakhundov/11/head branch November 18, 2024 02:10