KEMBAR78
[MPS] Add `grid_sampler_3d` for MPS by kurtamohler · Pull Request #160541 · pytorch/pytorch · GitHub
Skip to content

Conversation

@kurtamohler
Copy link
Collaborator

@kurtamohler kurtamohler commented Aug 13, 2025

Stack from ghstack (oldest at bottom):

This PR adds support for grid_sampler_3d for MPS with "bilinear" interpolation.

NOTE: "nearest" interpolation is not yet supported

Fixes #159882

[ghstack-poisoned]
@pytorch-bot
Copy link

pytorch-bot bot commented Aug 13, 2025

🔗 Helpful Links

🧪 See artifacts and rendered test results at hud.pytorch.org/pr/160541

Note: Links to docs will display an error until the docs builds have been completed.

❌ 1 New Failure

As of commit 026d903 with merge base 3650989 (image):

NEW FAILURE - The following job has failed:

This comment was automatically generated by Dr. CI and updates every 15 minutes.

@pytorch-bot pytorch-bot bot added the ciflow/mps Run MPS tests (subset of trunk) label Aug 13, 2025
kurtamohler added a commit that referenced this pull request Aug 13, 2025
ghstack-source-id: db4ecbc
Pull-Request: #160541
@pytorch-bot pytorch-bot bot added the release notes: mps Release notes category label Aug 13, 2025
@github-actions
Copy link
Contributor

Attention! native_functions.yaml was changed

If you are adding a new function or defaulted argument to native_functions.yaml, you cannot use it from pre-existing Python frontend code until our FC window passes (two weeks). Split your PR into two PRs, one which adds the new C++ functionality, and one that makes use of it from Python, and land them two weeks apart. See https://github.com/pytorch/pytorch/wiki/PyTorch's-Python-Frontend-Backward-and-Forward-Compatibility-Policy#forwards-compatibility-fc for more info.


Caused by:

[ghstack-poisoned]
kurtamohler added a commit that referenced this pull request Aug 13, 2025
ghstack-source-id: b9ed053
Pull-Request: #160541
@kurtamohler
Copy link
Collaborator Author

kurtamohler commented Aug 14, 2025

I tried using opmath_t to fix the difference between CPU and MPS results for half types. That alone did not fix it. But then I wrote a test that compares the results of half and full precision types both on MPS, and those do match within a relatively low tolerance.

Then if I try removing the opmath_t, the half and full precision results on MPS become significantly different from each other again. And it looks like the CPU impl does not use opmath_t, so since each output element is calculated with a fairly large number of dependent multiplications and additions, I believe that explains why the CPU and MPS results differ greatly for half precision.

So I'll update the PR to use opmath_t and add the half-vs-full precision test I wrote. We'll have to continue skipping the CPU-vs-MPS test for half precision, unless the CPU impl is updated at some point to use opmath_t

[ghstack-poisoned]
kurtamohler added a commit that referenced this pull request Aug 14, 2025
ghstack-source-id: 767ed96
Pull-Request: #160541
[ghstack-poisoned]
kurtamohler added a commit that referenced this pull request Aug 14, 2025
ghstack-source-id: 89f1d84
Pull-Request: #160541
Copy link
Contributor

@malfet malfet left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

LGTM

@malfet
Copy link
Contributor

malfet commented Aug 15, 2025

@pytorchbot merge -f "Lint + MPS are green"

@pytorchmergebot
Copy link
Collaborator

Merge started

Your change will be merged immediately since you used the force (-f) flag, bypassing any CI checks (ETA: 1-5 minutes). Please use -f as last resort and instead consider -i/--ignore-current to continue the merge ignoring current failures. This will allow currently pending tests to finish and report signal before the merge.

Learn more about merging in the wiki.

Questions? Feedback? Please reach out to the PyTorch DevX Team

Advanced Debugging
Check the merge workflow status
here

malfet added a commit that referenced this pull request Aug 18, 2025
This fixes following warnings during the compilation of GridSampler.metal
```
/Users/malfet/git/pytorch/pytorch/aten/src/ATen/native/mps/kernels/GridSampler.metal:22:23: warning: unused parameter 'input_sizes' [-Wunused-parameter]
    constant int32_t* input_sizes,
                      ^
/Users/malfet/git/pytorch/pytorch/aten/src/ATen/native/mps/kernels/GridSampler.metal:24:23: warning: unused parameter 'grid_sizes' [-Wunused-parameter]
    constant int32_t* grid_sizes,
                      ^
2 warnings generated.
```

Introduced by #160541

[ghstack-poisoned]
malfet added a commit that referenced this pull request Aug 18, 2025
This fixes following warnings during the compilation of GridSampler.metal
```
/Users/malfet/git/pytorch/pytorch/aten/src/ATen/native/mps/kernels/GridSampler.metal:22:23: warning: unused parameter 'input_sizes' [-Wunused-parameter]
    constant int32_t* input_sizes,
                      ^
/Users/malfet/git/pytorch/pytorch/aten/src/ATen/native/mps/kernels/GridSampler.metal:24:23: warning: unused parameter 'grid_sizes' [-Wunused-parameter]
    constant int32_t* grid_sizes,
                      ^
2 warnings generated.
```

Introduced by #160541

ghstack-source-id: fa913d3
Pull Request resolved: #160850
pytorchmergebot pushed a commit that referenced this pull request Aug 18, 2025
This fixes following warnings during the compilation of GridSampler.metal
```
/Users/malfet/git/pytorch/pytorch/aten/src/ATen/native/mps/kernels/GridSampler.metal:22:23: warning: unused parameter 'input_sizes' [-Wunused-parameter]
    constant int32_t* input_sizes,
                      ^
/Users/malfet/git/pytorch/pytorch/aten/src/ATen/native/mps/kernels/GridSampler.metal:24:23: warning: unused parameter 'grid_sizes' [-Wunused-parameter]
    constant int32_t* grid_sizes,
                      ^
2 warnings generated.
```

Introduced by #160541
Pull Request resolved: #160850
Approved by: https://github.com/cyyever, https://github.com/Skylion007
can-gaa-hou pushed a commit to can-gaa-hou/pytorch that referenced this pull request Aug 22, 2025
This PR adds support for `grid_sampler_3d` for MPS with "bilinear" interpolation.

NOTE: "nearest" interpolation is not yet supported

Fixes pytorch#159882
Pull Request resolved: pytorch#160541
Approved by: https://github.com/malfet
can-gaa-hou pushed a commit to can-gaa-hou/pytorch that referenced this pull request Aug 22, 2025
This fixes following warnings during the compilation of GridSampler.metal
```
/Users/malfet/git/pytorch/pytorch/aten/src/ATen/native/mps/kernels/GridSampler.metal:22:23: warning: unused parameter 'input_sizes' [-Wunused-parameter]
    constant int32_t* input_sizes,
                      ^
/Users/malfet/git/pytorch/pytorch/aten/src/ATen/native/mps/kernels/GridSampler.metal:24:23: warning: unused parameter 'grid_sizes' [-Wunused-parameter]
    constant int32_t* grid_sizes,
                      ^
2 warnings generated.
```

Introduced by pytorch#160541
Pull Request resolved: pytorch#160850
Approved by: https://github.com/cyyever, https://github.com/Skylion007
@github-actions github-actions bot deleted the gh/kurtamohler/46/head branch September 15, 2025 02:15
markc-614 pushed a commit to markc-614/pytorch that referenced this pull request Sep 17, 2025
This PR adds support for `grid_sampler_3d` for MPS with "bilinear" interpolation.

NOTE: "nearest" interpolation is not yet supported

Fixes pytorch#159882
Pull Request resolved: pytorch#160541
Approved by: https://github.com/malfet
markc-614 pushed a commit to markc-614/pytorch that referenced this pull request Sep 17, 2025
This fixes following warnings during the compilation of GridSampler.metal
```
/Users/malfet/git/pytorch/pytorch/aten/src/ATen/native/mps/kernels/GridSampler.metal:22:23: warning: unused parameter 'input_sizes' [-Wunused-parameter]
    constant int32_t* input_sizes,
                      ^
/Users/malfet/git/pytorch/pytorch/aten/src/ATen/native/mps/kernels/GridSampler.metal:24:23: warning: unused parameter 'grid_sizes' [-Wunused-parameter]
    constant int32_t* grid_sizes,
                      ^
2 warnings generated.
```

Introduced by pytorch#160541
Pull Request resolved: pytorch#160850
Approved by: https://github.com/cyyever, https://github.com/Skylion007
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment

Labels

ciflow/mps Run MPS tests (subset of trunk) Merged open source release notes: mps Release notes category

Projects

None yet

Development

Successfully merging this pull request may close these issues.

4 participants