KEMBAR78
[inductor][user triton] Check size hints to determine indexing dtype by davidberard98 · Pull Request #137234 · pytorch/pytorch · GitHub
Skip to content

Conversation

@davidberard98
Copy link
Contributor

@davidberard98 davidberard98 commented Oct 2, 2024

Stack from ghstack (oldest at bottom):

Previously, all integer inputs to user-defined triton kernels were assumed to be int32. This would result in errors if your input was actually an int64.

This PR checks the value to determine which dtype to use for indexing: if it is known to be < int_max, then use int32 (and add guards if relevant); if we can't check (e.g. unbacked symint), then use int64.

cc @voznesenskym @penguinwu @EikanWang @jgong5 @Guobing-Chen @XiaobingSuper @zhuhaozhe @blzheng @wenzhe-nrv @jiayisunx @ipiszy @yf225 @chenyang78 @kadeng @muchulee8 @ColinPeppler @amjames @desertfire @chauhang @oulgen @aakhundov

Differential Revision: D63797975

@pytorch-bot
Copy link

pytorch-bot bot commented Oct 2, 2024

🔗 Helpful Links

🧪 See artifacts and rendered test results at hud.pytorch.org/pr/137234

Note: Links to docs will display an error until the docs builds have been completed.

✅ No Failures

As of commit 5de8e4d with merge base 63bbf71 (image):
💚 Looks good so far! There are no failures yet. 💚

This comment was automatically generated by Dr. CI and updates every 15 minutes.

@davidberard98
Copy link
Contributor Author

@davidberard98 has imported this pull request. If you are a Meta employee, you can view this diff on Phabricator.

…nt32 vs tl.int64 indexing"

cc voznesenskym penguinwu EikanWang jgong5 Guobing-Chen XiaobingSuper zhuhaozhe blzheng wenzhe-nrv jiayisunx ipiszy yf225 chenyang78 kadeng muchulee8 ColinPeppler amjames desertfire chauhang

Differential Revision: [D63797975](https://our.internmc.facebook.com/intern/diff/D63797975)

[ghstack-poisoned]
davidberard98 added a commit that referenced this pull request Oct 2, 2024
…int64 indexing

ghstack-source-id: a0ab2be
Pull Request resolved: #137234
@davidberard98
Copy link
Contributor Author

@davidberard98 has imported this pull request. If you are a Meta employee, you can view this diff on Phabricator.

@davidberard98 davidberard98 changed the title [inductor][user triton] Check size hints to determine tl.int32 vs tl.int64 indexing [inductor][user triton] Check size hints to determine indexing dtype Oct 3, 2024
@davidberard98
Copy link
Contributor Author

@davidberard98 has imported this pull request. If you are a Meta employee, you can view this diff on Phabricator.

@davidberard98 davidberard98 added the module: user triton related to ability to directly torch.compile triton kernels label Oct 3, 2024
@davidberard98 davidberard98 marked this pull request as ready for review October 3, 2024 16:49
elif isinstance(arg.expr, (float, sympy.Float)):
return "fp32"

# if this is a integer
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Super-nit: an integer :)

Comment on lines 1707 to 1709
seed = torch.randint(
low=0, high=2**62, size=(1,), dtype=torch.int64
).item()
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Do we want to explicitly put this in the above int32 range, so that we don't get errors passing silently? Or maybe check both within and outside int32?

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Good point, updated it to [2**32, 2**62).

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Thanks for doing this, although I realized how insanely unlikely it is sample int32 from [2, 2**62] :)

…xing dtype"


Previously, all integer inputs to user-defined triton kernels were assumed to be int32. This would result in errors if your input was actually an int64.

This PR checks the value to determine which dtype to use for indexing: if it is known to be < int_max, then use int32 (and add guards if relevant); if we can't check (e.g. unbacked symint), then use int64.

cc voznesenskym penguinwu EikanWang jgong5 Guobing-Chen XiaobingSuper zhuhaozhe blzheng wenzhe-nrv jiayisunx ipiszy yf225 chenyang78 kadeng muchulee8 ColinPeppler amjames desertfire chauhang oulgen aakhundov

Differential Revision: [D63797975](https://our.internmc.facebook.com/intern/diff/D63797975)

[ghstack-poisoned]
…xing dtype"


Previously, all integer inputs to user-defined triton kernels were assumed to be int32. This would result in errors if your input was actually an int64.

This PR checks the value to determine which dtype to use for indexing: if it is known to be < int_max, then use int32 (and add guards if relevant); if we can't check (e.g. unbacked symint), then use int64.

cc voznesenskym penguinwu EikanWang jgong5 Guobing-Chen XiaobingSuper zhuhaozhe blzheng wenzhe-nrv jiayisunx ipiszy yf225 chenyang78 kadeng muchulee8 ColinPeppler amjames desertfire chauhang oulgen aakhundov

Differential Revision: [D63797975](https://our.internmc.facebook.com/intern/diff/D63797975)

[ghstack-poisoned]
davidberard98 added a commit that referenced this pull request Oct 3, 2024
…int64 indexing

ghstack-source-id: 9203c76
Pull Request resolved: #137234
@davidberard98
Copy link
Contributor Author

@davidberard98 has imported this pull request. If you are a Meta employee, you can view this diff on Phabricator.

@pytorch-bot pytorch-bot bot added the ciflow/trunk Trigger trunk jobs on your pull request label Oct 3, 2024
@davidberard98
Copy link
Contributor Author

@pytorchbot merge

@pytorchmergebot
Copy link
Collaborator

Merge started

Your change will be merged once all checks pass (ETA 0-4 Hours).

Learn more about merging in the wiki.

Questions? Feedback? Please reach out to the PyTorch DevX Team

Advanced Debugging
Check the merge workflow status
here

@github-actions github-actions bot deleted the gh/davidberard98/340/head branch November 6, 2024 02:09
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment

Labels

ciflow/inductor ciflow/trunk Trigger trunk jobs on your pull request Merged module: inductor module: user triton related to ability to directly torch.compile triton kernels release notes: inductor topic: bug fixes topic category

Projects

None yet

Development

Successfully merging this pull request may close these issues.

4 participants