KEMBAR78
[MPS] Add i0 op by malfet · Pull Request #137849 · pytorch/pytorch · GitHub
Skip to content

Conversation

@malfet
Copy link
Contributor

@malfet malfet commented Oct 12, 2024

More-or-less verbatim copy of

JITERATOR_HOST_DEVICE T calc_i0e(T _x) {

Plus a bit of a MPS boilerplate code

Update test_mps.py to mark kaiser_window and i0 as passing

@malfet malfet requested a review from kulinseth as a code owner October 12, 2024 20:34
@pytorch-bot
Copy link

pytorch-bot bot commented Oct 12, 2024

🔗 Helpful Links

🧪 See artifacts and rendered test results at hud.pytorch.org/pr/137849

Note: Links to docs will display an error until the docs builds have been completed.

✅ No Failures

As of commit af0f1f2 with merge base bc232e3 (image):
💚 Looks good so far! There are no failures yet. 💚

This comment was automatically generated by Dr. CI and updates every 15 minutes.

@pytorch-bot pytorch-bot bot added ciflow/mps Run MPS tests (subset of trunk) release notes: mps Release notes category labels Oct 12, 2024
@github-actions
Copy link
Contributor

Attention! native_functions.yaml was changed

If you are adding a new function or defaulted argument to native_functions.yaml, you cannot use it from pre-existing Python frontend code until our FC window passes (two weeks). Split your PR into two PRs, one which adds the new C++ functionality, and one that makes use of it from Python, and land them two weeks apart. See https://github.com/pytorch/pytorch/wiki/PyTorch's-Python-Frontend-Backward-and-Forward-Compatibility-Policy#forwards-compatibility-fc for more info.


Caused by:

@malfet malfet added the topic: improvements topic category label Oct 12, 2024
6.88975834691682398426E-5, 3.36911647825569408990E-3,
8.04490411014108831608E-1};

return static_cast<T>((exp(x) * chbevl(32.0 / x - 2.0, B, 25)) / sqrt(x));
Copy link
Collaborator

@Skylion007 Skylion007 Oct 13, 2024

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Suggested change
return static_cast<T>((exp(x) * chbevl(32.0 / x - 2.0, B, 25)) / sqrt(x));
return static_cast<T>((exp(x) * chbevl(32.0 / x - 2.0, B, 25)) * rsqrt(x));

Not sure if this affects the numeric, but it should have a special rsqrt instruction we can exploit

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Compiler should be able to figure out that 1.0 / sqrt could be replaced with rsqrt, but sure multiplication of 3 elements indeed sounds nice

Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Actually shouldn't be able to since the numerics differ ever so slightly apparently.

malfet added a commit that referenced this pull request Oct 14, 2024
To match behavior for torch.special.i0 

Noticed while looking at the failures in #137849
@malfet
Copy link
Contributor Author

malfet commented Oct 14, 2024

@pytorchbot merge -f "Lint + MPS are green"

@pytorchmergebot
Copy link
Collaborator

Merge started

Your change will be merged immediately since you used the force (-f) flag, bypassing any CI checks (ETA: 1-5 minutes). Please use -f as last resort and instead consider -i/--ignore-current to continue the merge ignoring current failures. This will allow currently pending tests to finish and report signal before the merge.

Learn more about merging in the wiki.

Questions? Feedback? Please reach out to the PyTorch DevX Team

Advanced Debugging
Check the merge workflow status
here

pytorchmergebot pushed a commit that referenced this pull request Oct 15, 2024
To match behavior of `torch.special.i0`

Noticed while looking at the failures in #137849

Also, add explicit high-precision template specialization for  `calc_i0` and `calc_i1` for `BFloat16` and `Half`

Pull Request resolved: #137899
Approved by: https://github.com/Skylion007
@github-actions github-actions bot deleted the malfet/mps-add-i0-op branch November 14, 2024 02:07
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment

Labels

ciflow/mps Run MPS tests (subset of trunk) Merged release notes: mps Release notes category topic: improvements topic category

Projects

None yet

Development

Successfully merging this pull request may close these issues.

3 participants