[MPS] Improve performance of max_pool3d by kurtamohler · Pull Request #157875 · pytorch/pytorch · GitHub

Conversation

@kurtamohler
Collaborator

@kurtamohler kurtamohler commented Jul 9, 2025

Stack from ghstack (oldest at bottom):

To check how the changes from this PR affect performance, I wrote a script here: https://github.com/kurtamohler/pytorch-perf-test-scripts/blob/55ef32a127c746d13d7310375068a6b300bda92d/max_pool_mps/perf.py.

Before this PR, I get this:

===================
max_pool3d
===================
0: 0.013105 ms, max_pool3d, (3, 2, 2, 2), {'kernel_size': 2}
1: 0.038003 ms, max_pool3d, (3, 10, 10, 10), {'kernel_size': 5}
2: 0.212963 ms, max_pool3d, (3, 100, 100, 100), {'kernel_size': 5}
3: 1.224645 ms, max_pool3d, (3, 200, 200, 200), {'kernel_size': 5}
4: 7.317867 ms, max_pool3d, (10, 10, 100, 100, 100), {'kernel_size': 4, 'padding': 1}
5: 34.679233 ms, max_pool3d, (10, 10, 100, 100, 100), {'kernel_size': 50, 'padding': 20}
6: 34.626383 ms, max_pool3d, (10, 10, 100, 100, 100), {'kernel_size': 50, 'padding': 20, 'dilation': 1}
7: 44.835892 ms, max_pool3d, (10, 10, 100, 100, 100), {'kernel_size': 50, 'padding': 20, 'dilation': 1, 'stride': 40}
8: 0.083579 ms, max_pool3d, (10, 10, 10, 10, 10), {'kernel_size': 2}
9: 0.936575 ms, max_pool3d, (10, 10, 30, 30, 30), {'kernel_size': 2}
10: 5.329883 ms, max_pool3d, (10, 10, 50, 50, 50), {'kernel_size': 2}
11: 11.713617 ms, max_pool3d, (10, 10, 70, 70, 70), {'kernel_size': 2}
12: 25.450454 ms, max_pool3d, (10, 10, 90, 90, 90), {'kernel_size': 2}
13: 0.058375 ms, max_pool3d, (10, 10, 10, 10, 10), {'kernel_size': 2, 'dilation': 2}
14: 3.757558 ms, max_pool3d, (10, 10, 50, 50, 50), {'kernel_size': 2, 'dilation': 2}
15: 33.451588 ms, max_pool3d, (10, 10, 100, 100, 100), {'kernel_size': 2, 'dilation': 2}

After this PR, I get this:

===================
max_pool3d
===================
0: 0.007202 ms, max_pool3d, (3, 2, 2, 2), {'kernel_size': 2}
1: 0.018596 ms, max_pool3d, (3, 10, 10, 10), {'kernel_size': 5}
2: 0.130717 ms, max_pool3d, (3, 100, 100, 100), {'kernel_size': 5}
3: 0.966795 ms, max_pool3d, (3, 200, 200, 200), {'kernel_size': 5}
4: 4.095804 ms, max_pool3d, (10, 10, 100, 100, 100), {'kernel_size': 4, 'padding': 1}
5: 12.833446 ms, max_pool3d, (10, 10, 100, 100, 100), {'kernel_size': 50, 'padding': 20}
6: 12.859346 ms, max_pool3d, (10, 10, 100, 100, 100), {'kernel_size': 50, 'padding': 20, 'dilation': 1}
7: 14.080529 ms, max_pool3d, (10, 10, 100, 100, 100), {'kernel_size': 50, 'padding': 20, 'dilation': 1, 'stride': 40}
8: 0.029283 ms, max_pool3d, (10, 10, 10, 10, 10), {'kernel_size': 2}
9: 0.175700 ms, max_pool3d, (10, 10, 30, 30, 30), {'kernel_size': 2}
10: 0.742750 ms, max_pool3d, (10, 10, 50, 50, 50), {'kernel_size': 2}
11: 1.939596 ms, max_pool3d, (10, 10, 70, 70, 70), {'kernel_size': 2}
12: 4.074821 ms, max_pool3d, (10, 10, 90, 90, 90), {'kernel_size': 2}
13: 0.028425 ms, max_pool3d, (10, 10, 10, 10, 10), {'kernel_size': 2, 'dilation': 2}
14: 0.384375 ms, max_pool3d, (10, 10, 50, 50, 50), {'kernel_size': 2, 'dilation': 2}
15: 2.623346 ms, max_pool3d, (10, 10, 100, 100, 100), {'kernel_size': 2, 'dilation': 2}

Every case is improved.
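The linked script is not reproduced here, but median run times like the ones above can be collected with a harness along these lines (a minimal sketch of the approach, not the author's script):

```python
import statistics
import time

def median_time_ms(fn, warmup=3, iters=20):
    """Median wall-clock time of fn() over `iters` runs, in milliseconds,
    after a few untimed warmup runs."""
    for _ in range(warmup):
        fn()
    samples = []
    for _ in range(iters):
        t0 = time.perf_counter()
        fn()
        samples.append((time.perf_counter() - t0) * 1000.0)
    return statistics.median(samples)
```

For a GPU backend such as MPS, the timed callable would also need to synchronize the device (e.g. `torch.mps.synchronize()`) so the kernel actually finishes inside the measured window.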

[ghstack-poisoned]
@pytorch-bot

pytorch-bot bot commented Jul 9, 2025

🔗 Helpful Links

🧪 See artifacts and rendered test results at hud.pytorch.org/pr/157875

Note: Links to docs will display an error until the docs builds have been completed.

✅ You can merge normally! (1 Unrelated Failure)

As of commit 09330ba with merge base ee09928 (image):

UNSTABLE - The following job is marked as unstable, possibly due to flakiness on trunk:

This comment was automatically generated by Dr. CI and updates every 15 minutes.

[ghstack-poisoned]
kurtamohler added a commit to kurtamohler/pytorch that referenced this pull request Jul 9, 2025
@kurtamohler
Collaborator Author

To check how the changes from this PR and #157874 affect performance, I wrote a script here: https://github.com/kurtamohler/pytorch-perf-test-scripts/blob/55ef32a127c746d13d7310375068a6b300bda92d/max_pool_mps/perf.py

Before these two PRs, I get the following median run times for various cases:

===================
max_pool3d
===================
0: 0.020005 ms, max_pool3d, (3, 2, 2, 2), {'kernel_size': 2}
1: 0.047002 ms, max_pool3d, (3, 10, 10, 10), {'kernel_size': 5}
2: 0.224255 ms, max_pool3d, (3, 100, 100, 100), {'kernel_size': 5}
3: 1.212816 ms, max_pool3d, (3, 200, 200, 200), {'kernel_size': 5}
4: 7.184867 ms, max_pool3d, (10, 10, 100, 100, 100), {'kernel_size': 4, 'padding': 1}
5: 34.662742 ms, max_pool3d, (10, 10, 100, 100, 100), {'kernel_size': 50, 'padding': 20}
6: 34.658600 ms, max_pool3d, (10, 10, 100, 100, 100), {'kernel_size': 50, 'padding': 20, 'dilation': 1}
7: 44.815712 ms, max_pool3d, (10, 10, 100, 100, 100), {'kernel_size': 50, 'padding': 20, 'dilation': 1, 'stride': 40}
8: 0.109283 ms, max_pool3d, (10, 10, 10, 10, 10), {'kernel_size': 2}
9: 0.943642 ms, max_pool3d, (10, 10, 30, 30, 30), {'kernel_size': 2}
10: 4.420575 ms, max_pool3d, (10, 10, 50, 50, 50), {'kernel_size': 2}
11: 11.855496 ms, max_pool3d, (10, 10, 70, 70, 70), {'kernel_size': 2}
12: 26.128042 ms, max_pool3d, (10, 10, 90, 90, 90), {'kernel_size': 2}
13: 0.069175 ms, max_pool3d, (10, 10, 10, 10, 10), {'kernel_size': 2, 'dilation': 2}
14: 3.790779 ms, max_pool3d, (10, 10, 50, 50, 50), {'kernel_size': 2, 'dilation': 2}
15: 33.308371 ms, max_pool3d, (10, 10, 100, 100, 100), {'kernel_size': 2, 'dilation': 2}

After applying these two PRs, I get the following measurements:

===================
max_pool3d
===================
0: 0.008978 ms, max_pool3d, (3, 2, 2, 2), {'kernel_size': 2}
1: 0.027260 ms, max_pool3d, (3, 10, 10, 10), {'kernel_size': 5}
2: 0.143246 ms, max_pool3d, (3, 100, 100, 100), {'kernel_size': 5}
3: 0.952385 ms, max_pool3d, (3, 200, 200, 200), {'kernel_size': 5}
4: 4.052346 ms, max_pool3d, (10, 10, 100, 100, 100), {'kernel_size': 4, 'padding': 1}
5: 21.206533 ms, max_pool3d, (10, 10, 100, 100, 100), {'kernel_size': 50, 'padding': 20}
6: 21.219742 ms, max_pool3d, (10, 10, 100, 100, 100), {'kernel_size': 50, 'padding': 20, 'dilation': 1}
7: 22.074067 ms, max_pool3d, (10, 10, 100, 100, 100), {'kernel_size': 50, 'padding': 20, 'dilation': 1, 'stride': 40}
8: 0.035204 ms, max_pool3d, (10, 10, 10, 10, 10), {'kernel_size': 2}
9: 0.182700 ms, max_pool3d, (10, 10, 30, 30, 30), {'kernel_size': 2}
10: 0.744158 ms, max_pool3d, (10, 10, 50, 50, 50), {'kernel_size': 2}
11: 1.937175 ms, max_pool3d, (10, 10, 70, 70, 70), {'kernel_size': 2}
12: 4.063979 ms, max_pool3d, (10, 10, 90, 90, 90), {'kernel_size': 2}
13: 0.030263 ms, max_pool3d, (10, 10, 10, 10, 10), {'kernel_size': 2, 'dilation': 2}
14: 0.503246 ms, max_pool3d, (10, 10, 50, 50, 50), {'kernel_size': 2, 'dilation': 2}
15: 3.905775 ms, max_pool3d, (10, 10, 100, 100, 100), {'kernel_size': 2, 'dilation': 2}

Every one of the cases is improved.
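The improvement can be quantified per case by dividing the "before" time by the "after" time; for instance (values in ms, copied from cases 0, 5, and 15 of the two tables above):

```python
# Speedup per case, computed from the before/after median run times above.
before = {0: 0.020005, 5: 34.662742, 15: 33.308371}
after = {0: 0.008978, 5: 21.206533, 15: 3.905775}

speedup = {k: before[k] / after[k] for k in before}
for k, s in sorted(speedup.items()):
    print(f"case {k}: {s:.2f}x faster")
```

The dilated small-kernel case (15) sees the largest gain, at roughly 8.5x.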

[ghstack-poisoned]
kurtamohler added a commit to kurtamohler/pytorch that referenced this pull request Jul 9, 2025
Contributor

@malfet malfet left a comment


Please avoid hardcoding types (unless you need them for some reason); instead, rely on the compiler to infer them. That way, when you decide to change a type later, you don't have to worry about changing it throughout the codebase.

Also, use Metal's native uint3/uint4 data types to pass fixed-size arrays around (they can be accessed both by index and by .x/.y suffix).

constant int32_t* input_sizes,
constant int32_t* input_strides,
thread int32_t (&work_pooling_dim_indices)[3],
constant int32_t* kernel_size,
Contributor


Have you considered using vectorized types? They result in faster code and look more natural.

Suggested change
constant int32_t* kernel_size,
constant uint3& kernel_size,
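Metal's built-in int3/uint3 types support componentwise arithmetic and both `.x`/`.y`/`.z` and indexed access. As a rough model of those semantics (purely illustrative Python, not Metal code):

```python
class Int3:
    """Toy model of Metal's int3: componentwise ops, .x/.y/.z or [i] access."""

    def __init__(self, x, y, z):
        self._v = [x, y, z]

    x = property(lambda self: self._v[0])
    y = property(lambda self: self._v[1])
    z = property(lambda self: self._v[2])

    def __getitem__(self, i):
        return self._v[i]

    def __mul__(self, other):
        return Int3(*(a * b for a, b in zip(self._v, other._v)))

    def __sub__(self, other):
        return Int3(*(a - b for a, b in zip(self._v, other._v)))

# One vector expression replaces three scalar ones, e.g. the window start:
stride, idx, pad = Int3(2, 2, 2), Int3(4, 5, 6), Int3(1, 1, 1)
start = stride * idx - pad  # componentwise: (7, 9, 11)
```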

Collaborator Author


I tried using vectorized types with the following diff, but it seemingly did not change performance at all. I'm guessing it compiles to something equivalent. I opted not to include it in the PR because I think it makes the code a little less readable, but if you still prefer it, I can add it.

diff --git a/aten/src/ATen/native/mps/kernels/Pooling.metal b/aten/src/ATen/native/mps/kernels/Pooling.metal
index 18982559a34..53a4bf8f2a7 100644
--- a/aten/src/ATen/native/mps/kernels/Pooling.metal
+++ b/aten/src/ATen/native/mps/kernels/Pooling.metal
@@ -28,6 +28,20 @@ IterBounds<int32_t> get_input_iter_bounds(
   return IterBounds<int32_t>{start, end};
 }
 
+IterBounds<int3> get_input_iter_bounds(
+    constant int3& input_sizes,
+    thread int3& pooling_dim_indices,
+    constant int3& kernel_size,
+    constant int3& stride,
+    constant int3& padding,
+    constant int3& dilation) {
+  auto start = stride * pooling_dim_indices - padding;
+  auto end = min(start + kernel_size * dilation, input_sizes);
+  auto start_correction = dilation * ((-start - 1 + dilation) / dilation);
+  start += select(int3(0), start_correction, start < int3(0));
+  return IterBounds<int3>{start, end};
+}
+
 // Iterates through all the input elements that this kernel needs to
 // apply max to. Specialized for 3 pooling dimensions.
 // TODO: Support any number of pooling dims
@@ -88,6 +102,55 @@ void max_pool_3d_input_iter(
   }
 }
 
+template <typename T>
+void max_pool_3d_input_iter(
+    constant T* input,
+    device T* output,
+    device int64_t* indices,
+    constant int3& input_sizes,
+    constant int3& input_strides,
+    thread int3& pooling_dim_indices,
+    constant int3& kernel_size,
+    constant int3& stride,
+    constant int3& padding,
+    constant int3& dilation,
+    bool return_indices) {
+  auto bounds = get_input_iter_bounds(
+      input_sizes, pooling_dim_indices, kernel_size, stride, padding, dilation);
+
+  T max_value = input
+      [input_strides.x * bounds.start.x + input_strides.y * bounds.start.y +
+       input_strides.z * bounds.start.z];
+  auto size12 = input_sizes.y * input_sizes.z;
+  auto max_index =
+      bounds.start.x * size12 + bounds.start.y * input_sizes.z + bounds.start.z;
+
+  for (auto i0 = bounds.start.x; i0 < bounds.end.x; i0 += dilation.x) {
+    auto offset0 = input_strides.x * i0;
+
+    for (auto i1 = bounds.start.y; i1 < bounds.end.y; i1 += dilation.y) {
+      auto offset1 = input_strides.y * i1;
+
+      for (auto i2 = bounds.start.z; i2 < bounds.end.z; i2 += dilation.z) {
+        auto offset2 = input_strides.z * i2;
+        auto input_value = input[offset0 + offset1 + offset2];
+        bool is_greater = input_value > max_value;
+
+        max_value = is_greater ? input_value : max_value;
+
+        if (return_indices) {
+          auto input_index = i0 * size12 + i1 * input_sizes.z + i2;
+          max_index = is_greater ? input_index : max_index;
+        }
+      }
+    }
+  }
+  *output = max_value;
+  if (return_indices) {
+    *indices = max_index;
+  }
+}
+
 struct PoolOffsets {
   int32_t output;
   int32_t indices;
@@ -212,18 +275,41 @@ kernel void max_pool(
   output += offsets.output;
   indices += offsets.indices;
   input += offsets.input_leading;
+  input_sizes += leading_dims;
+  input_strides += leading_dims;
+
+  constant int3& input_sizes_ = *reinterpret_cast<constant int3*>(input_sizes);
+  constant int3& input_strides_ = *reinterpret_cast<constant int3*>(input_strides);
+  thread int3& pooling_dim_indices_ = *reinterpret_cast<thread int3*>(pooling_dim_indices);
+  constant int3& kernel_size_ = *reinterpret_cast<constant int3*>(kernel_size);
+  constant int3& stride_ = *reinterpret_cast<constant int3*>(stride);
+  constant int3& padding_ = *reinterpret_cast<constant int3*>(padding);
+  constant int3& dilation_ = *reinterpret_cast<constant int3*>(dilation);
+
+  //max_pool_3d_input_iter<T>(
+  //    input,
+  //    output,
+  //    indices,
+  //    input_sizes,
+  //    input_strides,
+  //    pooling_dim_indices,
+  //    kernel_size,
+  //    stride,
+  //    padding,
+  //    dilation,
+  //    return_indices);
 
   max_pool_3d_input_iter<T>(
       input,
       output,
       indices,
-      input_sizes + leading_dims,
-      input_strides + leading_dims,
-      pooling_dim_indices,
-      kernel_size,
-      stride,
-      padding,
-      dilation,
+      input_sizes_,
+      input_strides_,
+      pooling_dim_indices_,
+      kernel_size_,
+      stride_,
+      padding_,
+      dilation_,
       return_indices);
 }
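The per-dimension bounds arithmetic in `get_input_iter_bounds` above can be sanity-checked on the CPU; here is a scalar Python transcription of the same computation (an illustrative check, not part of the PR):

```python
def iter_bounds(out_idx, input_size, kernel, stride, padding, dilation):
    """Scalar version of the bounds computation for one pooling dim:
    clamp the kernel window to the input, and round a negative start up
    to the first in-bounds tap of the dilated kernel."""
    start = stride * out_idx - padding
    end = min(start + kernel * dilation, input_size)
    if start < 0:
        # Smallest non-negative index reachable by steps of `dilation`.
        # (The numerator is positive here, so Python's floor division
        # matches Metal's truncating integer division.)
        start += dilation * ((-start - 1 + dilation) // dilation)
    return start, end

# Output index 0 with padding 1: the window starts at -1 and is clamped.
print(iter_bounds(0, 100, kernel=2, stride=2, padding=1, dilation=1))  # (0, 1)
```

The inner loops then visit `range(start, end, dilation)` in each dimension.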

@malfet
Contributor

malfet commented Jul 11, 2025

To check how the changes from this PR and #157874 affect performance, I wrote a script here: https://github.com/kurtamohler/pytorch-perf-test-scripts/blob/55ef32a127c746d13d7310375068a6b300bda92d/max_pool_mps/perf.py

Do you mind mentioning it for posterity in the PR description, and maybe also integrating it into test/perf_mps.py?

[ghstack-poisoned]
kurtamohler added a commit to kurtamohler/pytorch that referenced this pull request Jul 15, 2025
@kurtamohler
Collaborator Author

kurtamohler commented Jul 15, 2025

I've made a few more changes that improved performance, updated a bunch of places in the code to use auto, and added a note about the performance improvement to the PR description.

@kurtamohler
Collaborator Author

@pytorchbot merge

@pytorch-bot pytorch-bot bot added the ciflow/trunk Trigger trunk jobs on your pull request label Jul 17, 2025
@pytorchmergebot
Collaborator

Merge started

Your change will be merged once all checks pass (ETA 0-4 Hours).

Learn more about merging in the wiki.

Questions? Feedback? Please reach out to the PyTorch DevX Team

Advanced Debugging: check the merge workflow status here

@github-actions github-actions bot deleted the gh/kurtamohler/40/head branch August 17, 2025 02:19

Labels

ciflow/mps (Run MPS tests, subset of trunk), ciflow/trunk (Trigger trunk jobs on your pull request), Merged, open source, release notes: mps (Release notes category)


4 participants