Add host-side TMA support to AOTInductor by aakhundov · Pull Request #138878 · pytorch/pytorch · GitHub

Conversation

@aakhundov
Contributor

@aakhundov aakhundov commented Oct 25, 2024

Stack from ghstack (oldest at bottom):

This adds host-side Triton TMA support to AOTInductor. Notes:

  • Two helper functions, `init1DTMADescriptor` and `init2DTMADescriptor`, are added to the C++ wrapper codegen on GPU, conditioned on the model having user-defined Triton kernels with host-side TMA (CUDA-specific); a hedged usage sketch follows this list.
  • C++ wrapper codegen on GPU emits TMA descriptor initialization via the aforementioned helper functions.
  • Special handling is added for the TMA descriptors in the Python wrapper codegen during compile-time autotuning, as the underlying tensor can't be passed directly to the user-defined Triton kernel. TMA descriptors are generated between the source tensor's buffer and the kernel call, as in the full Python wrapper codegen.
  • This PR concludes the host-side Triton TMA support in PT2.
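
To ground the bullets above, here is a hedged, minimal sketch (not code from this PR) of the user-level pattern this stack enables: a user-defined Triton kernel consuming a host-side TMA descriptor, AOT-compilable with AOTInductor. It assumes Triton's experimental descriptor helpers as of late 2024 (`create_1d_tma_descriptor`, `tl._experimental_descriptor_load`), which may change; the AOT entry point in the trailing comment is likewise release-dependent.

```python
import torch
import triton
import triton.language as tl
from triton.tools.experimental_descriptor import create_1d_tma_descriptor


@triton.jit
def add_one_kernel(desc_ptr, out_ptr, BLOCK: tl.constexpr):
    pid = tl.program_id(0)
    # Load one BLOCK-sized tile through the host-side TMA descriptor.
    x = tl._experimental_descriptor_load(desc_ptr, [pid * BLOCK], [BLOCK], tl.float32)
    tl.store(out_ptr + pid * BLOCK + tl.arange(0, BLOCK), x + 1.0)


def fn(x: torch.Tensor) -> torch.Tensor:
    # Assumes x is contiguous fp32 on a Hopper GPU, with numel divisible by BLOCK.
    out = torch.empty_like(x)
    BLOCK = 128
    # Host-side descriptor creation: this is the call that the AOTInductor
    # C++ wrapper lowers to the init1DTMADescriptor helper described above.
    desc = create_1d_tma_descriptor(x.data_ptr(), x.numel(), BLOCK, x.element_size())
    add_one_kernel[(triton.cdiv(x.numel(), BLOCK),)](desc, out, BLOCK=BLOCK)
    return out


# AOT compilation entry point (hedged; the exact API varies by release), e.g.:
#   ep = torch.export.export(SomeModuleWrapping_fn, (x,))
#   torch._inductor.aoti_compile_and_package(ep)
```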

cc @voznesenskym @penguinwu @EikanWang @jgong5 @Guobing-Chen @XiaobingSuper @zhuhaozhe @blzheng @wenzhe-nrv @jiayisunx @ipiszy @yf225 @chenyang78 @kadeng @muchulee8 @ColinPeppler @amjames @desertfire @chauhang

[ghstack-poisoned]
@pytorch-bot

pytorch-bot bot commented Oct 25, 2024

🔗 Helpful Links

🧪 See artifacts and rendered test results at hud.pytorch.org/pr/138878

Note: Links to docs will display an error until the docs builds have been completed.

✅ No Failures

As of commit e417138 with merge base 72ea7ba:
💚 Looks good so far! There are no failures yet. 💚

This comment was automatically generated by Dr. CI and updates every 15 minutes.

aakhundov added a commit that referenced this pull request Oct 25, 2024
ghstack-source-id: c9551a1
Pull Request resolved: #138878
@aakhundov aakhundov marked this pull request as draft October 25, 2024 01:13
@aakhundov aakhundov added topic: not user facing topic category ciflow/trunk Trigger trunk jobs on your pull request labels Oct 25, 2024
@aakhundov aakhundov marked this pull request as ready for review October 25, 2024 04:52
```python
autotune_configs=configs,
)

def generate_tma_descriptor(self, desc):
```
Contributor

Because AOTI still uses a two-pass run at the moment, does the cpp wrapper codegen here need to read any information from the Python run? I have added things like `DeferredGpuGridLine` for that purpose. Do you think it is necessary to do something like that here?

Contributor Author

I believe the Python and C++ codegens for the TMA descriptor creation are independent.
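
For readers landing here, a generic, hedged sketch of the "deferred line" idea raised above (names invented for illustration; not the actual Inductor classes): during the two-pass run, the wrapper emits a placeholder whose text is rendered only after the first (Python) pass has fixed values such as the launch grid. Per the reply above, TMA descriptor codegen needs no such deferral because the Python and C++ paths don't share state.

```python
from typing import Callable, Sequence


class DeferredGridLine:
    """Illustrative stand-in: a wrapper line whose text depends on autotuning."""

    def __init__(self, kernel_name: str, get_grid: Callable[[], Sequence[int]]):
        self.kernel_name = kernel_name
        self.get_grid = get_grid  # resolved only after the Python (first) pass

    def render(self) -> str:
        grid = ", ".join(str(g) for g in self.get_grid())
        return f"launchKernel({self.kernel_name}, /*grid=*/{{{grid}}});"


# Usage sketch: the grid closure is populated once autotuning picks a config.
line = DeferredGridLine("add_one_kernel", lambda: (8, 1, 1))
print(line.render())  # launchKernel(add_one_kernel, /*grid=*/{8, 1, 1});
```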

[ghstack-poisoned]
aakhundov added a commit that referenced this pull request Oct 27, 2024
ghstack-source-id: ed091c0
Pull Request resolved: #138878
@aakhundov aakhundov requested a review from desertfire October 27, 2024 23:24
@aakhundov
Contributor Author

@pytorchbot merge

@pytorchmergebot
Collaborator

Merge started

Your change will be merged once all checks pass (ETA 0-4 Hours).

Learn more about merging in the wiki.

Questions? Feedback? Please reach out to the PyTorch DevX Team

Advanced Debugging: check the merge workflow status here.

rahulsingh-intel pushed a commit to rahulsingh-intel/pytorch that referenced this pull request Oct 29, 2024
This adds host-side Triton TMA support to AOTInductor. Notes:

- Two helper functions, `init1DTMADescriptor` and `init2DTMADescriptor`, are added to the C++ wrapper codegen on GPU, conditioned on the model having user-defined Triton kernels with host-side TMA (CUDA-specific).
- C++ wrapper codegen on GPU emits TMA descriptor initialization via the aforementioned helper functions.
- Special handling is added for the TMA descriptors in the Python wrapper codegen during compile-time autotuning, as the underlying tensor can't be passed directly to the user-defined Triton kernel. TMA descriptors are generated between the source tensor's buffer and the kernel call, as in the full Python wrapper codegen.
- This PR concludes the host-side Triton TMA support in PT2.

Pull Request resolved: pytorch#138878
Approved by: https://github.com/desertfire, https://github.com/chenyang78
ghstack dependencies: pytorch#138759, pytorch#138877
rahulsingh-intel pushed a commit to rahulsingh-intel/pytorch that referenced this pull request Nov 5, 2024
@github-actions github-actions bot deleted the gh/aakhundov/14/head branch November 28, 2024 02:13
