[FlexAttention] Fix max-autotune bug with captured buffer grads by drisspg · Pull Request #141531 · pytorch/pytorch · GitHub

Conversation

@drisspg
Contributor

@drisspg drisspg commented Nov 26, 2024

Stack from ghstack (oldest at bottom):

Summary

Fix tensor argument ordering when autotuning flex attention, and change how we enable scatter codegen for Triton. We used to go through the existing store_output Triton codegen path, but now we short-circuit and generate the correct expression earlier on.

Instead of relying on arg.python_defs to thread arguments through via input_buffers, this lets us reuse the same mutated-buffer infrastructure we already used for multiple outputs.

Test cases added for both default and max-autotune-no-cudagraphs modes.
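
A minimal sketch of the pattern this exercises (illustrative only, not the exact test added in this PR; names and shapes are made up): a score_mod that reads a captured buffer requiring grad, compiled with max-autotune-no-cudagraphs so the backward kernel has to scatter into the captured buffer's grad.

    import torch
    from torch.nn.attention.flex_attention import flex_attention

    # Captured buffer whose grad the backward kernel scatters into.
    bias = torch.randn(128, device="cuda", requires_grad=True)

    def score_mod(score, b, h, q_idx, kv_idx):
        # Reads the captured buffer, so its gradient has to be threaded
        # through the generated Triton kernel as a mutated buffer.
        return score + bias[q_idx]

    q, k, v = (
        torch.randn(2, 4, 128, 64, device="cuda", requires_grad=True)
        for _ in range(3)
    )

    compiled_fa = torch.compile(flex_attention, mode="max-autotune-no-cudagraphs")
    out = compiled_fa(q, k, v, score_mod=score_mod)
    out.sum().backward()
    assert bias.grad is not None  # grad for the captured buffer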

cc @voznesenskym @penguinwu @EikanWang @jgong5 @Guobing-Chen @XiaobingSuper @zhuhaozhe @blzheng @wenzhe-nrv @jiayisunx @ipiszy @yf225 @chenyang78 @kadeng @muchulee8 @ColinPeppler @amjames @desertfire @chauhang @aakhundov @Chillee @yanboliang @BoyuanFeng

[ghstack-poisoned]
@pytorch-bot

pytorch-bot bot commented Nov 26, 2024

🔗 Helpful Links

🧪 See artifacts and rendered test results at hud.pytorch.org/pr/141531

Note: Links to docs will display an error until the docs builds have been completed.

✅ You can merge normally! (1 Unrelated Failure)

As of commit fb56ad9 with merge base f472b3a:

UNSTABLE - The following job failed but was likely due to flakiness present on trunk and has been marked as unstable:

This comment was automatically generated by Dr. CI and updates every 15 minutes.

drisspg added a commit that referenced this pull request Nov 26, 2024
@drisspg drisspg requested a review from Chillee November 26, 2024 01:01
@drisspg drisspg added the topic: not user facing, ciflow/trunk, and module: flex attention labels Nov 26, 2024
[ghstack-poisoned]
drisspg added a commit that referenced this pull request Nov 27, 2024
@drisspg
Contributor Author

drisspg commented Nov 27, 2024

The problem here is that we pass the node an input that, in the non-autotune case, ends up being the last buffer argument when generating the args, while in max-autotune it ends up being second to last because we have an explicit output node. I need to figure out how to reorder the output nodes from the kernel.

Confirmed by doing a hacky swap in the autotune process:

        # Hacky swap: exchange the last input tensor with the output tensor so the
        # captured-buffer argument lands where the non-autotune path expects it.
        input_tensors = list(input_tensors)
        tmp = input_tensors[-1]
        input_tensors[-1] = output_tensor
        output_tensor = tmp

[ghstack-poisoned]
drisspg added a commit that referenced this pull request Nov 27, 2024
[ghstack-poisoned]
drisspg added a commit that referenced this pull request Dec 3, 2024
[ghstack-poisoned]
drisspg added a commit that referenced this pull request Dec 3, 2024
[ghstack-poisoned]
drisspg added a commit that referenced this pull request Dec 4, 2024
[ghstack-poisoned]
drisspg added a commit that referenced this pull request Dec 4, 2024
Collaborator

@Chillee Chillee left a comment


Could we add a couple more tests? Specifically I'd like a test with multiple captured grads.

@drisspg
Contributor Author

drisspg commented Dec 4, 2024

Yeah, we have this test for the default compile path, but I can add one for autotuning.
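
For reference, a sketch of what a multiple-captured-grads test could look like under autotuning (illustrative names and shapes, not necessarily what landed): two captured buffers, both requiring grad, so the kernel carries two mutated grad buffers whose argument order has to match the template.

    import torch
    from torch.nn.attention.flex_attention import flex_attention

    # Two captured buffers, both requiring grad.
    head_bias = torch.randn(4, device="cuda", requires_grad=True)
    pos_bias = torch.randn(128, device="cuda", requires_grad=True)

    def score_mod(score, b, h, q_idx, kv_idx):
        return score + head_bias[h] + pos_bias[kv_idx]

    q, k, v = (
        torch.randn(2, 4, 128, 64, device="cuda", requires_grad=True)
        for _ in range(3)
    )

    compiled_fa = torch.compile(flex_attention, mode="max-autotune-no-cudagraphs")
    compiled_fa(q, k, v, score_mod=score_mod).sum().backward()
    assert head_bias.grad is not None and pos_bias.grad is not None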

[ghstack-poisoned]
drisspg added a commit that referenced this pull request Dec 4, 2024
@drisspg
Contributor Author

drisspg commented Dec 4, 2024

@pytorchbot merge

@pytorchmergebot
Collaborator

Merge started

Your change will be merged once all checks pass (ETA 0-4 Hours).

Learn more about merging in the wiki.

Questions? Feedback? Please reach out to the PyTorch DevX Team

Advanced Debugging
Check the merge workflow status here

@pytorchmergebot
Collaborator

Merge failed

Reason: 1 mandatory check(s) failed. The first few are:

Dig deeper by viewing the failures on hud

Details for Dev Infra team. Raised by workflow job.

Failing merge rule: Core Maintainers

@drisspg
Contributor Author

drisspg commented Dec 4, 2024

@pytorchbot merge

@pytorchmergebot
Collaborator

Merge started

Your change will be merged once all checks pass (ETA 0-4 Hours).

Learn more about merging in the wiki.

Questions? Feedback? Please reach out to the PyTorch DevX Team

Advanced Debugging
Check the merge workflow status here

pobin6 pushed a commit to pobin6/pytorch that referenced this pull request Dec 5, 2024
…rch#141531)

Pull Request resolved: pytorch#141531
Approved by: https://github.com/Chillee
AmdSampsa pushed a commit to AmdSampsa/pytorch that referenced this pull request Dec 9, 2024
…rch#141531)

@github-actions github-actions bot deleted the gh/drisspg/87/head branch January 4, 2025 02:04