PixelShuffle is very slow on MPS compared to ConvTranspose2d and PixelShuffle on cuda · Issue #83196 · pytorch/pytorch · GitHub


@neil3706

Description


🐛 Describe the bug

I changed the decode layers of a CNN to use PixelShuffle instead of ConvTranspose2d, and the model immediately took 3x longer to run. GPU usage was still high, and there was no warning about falling back to the CPU. I recently started training the same CNN remotely on CUDA and saw training time drop by 66%. The only conclusion I can draw is that there is a problem with the MPS implementation of PixelShuffle.
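One way to narrow this down is to time the two decode paths in isolation on the same machine. The sketch below is a rough microbenchmark under assumed layer sizes (`IN_FM`, `OUT_FM`, `K`, `STRIDE`, `H`, `BATCH` and the helper names are illustrative, not taken from the reporter's model). It times forward plus backward for a `ConvTranspose2d` and for the `Conv2d` + `PixelShuffle` replacement, using MPS when available and the CPU otherwise.

```python
import time

import torch
import torch.nn as nn

# Illustrative sizes only (assumption); the issue does not state the real ones.
IN_FM, OUT_FM, K, STRIDE, H, BATCH = 64, 32, 3, 2, 32, 4


def make_tconv() -> nn.Module:
    # Original path: a transposed conv that upsamples by STRIDE.
    # With K=3, stride=2, padding=1, output_padding=1 the output is exactly 2*H.
    return nn.ConvTranspose2d(IN_FM, OUT_FM, K, stride=STRIDE, padding=1, output_padding=1)


def make_shuffle() -> nn.Module:
    # Replacement path: a same-resolution conv emitting OUT_FM * STRIDE**2 channels,
    # then PixelShuffle folds those channels into a STRIDE-times larger image.
    return nn.Sequential(
        nn.Conv2d(IN_FM, OUT_FM * STRIDE ** 2, K, stride=1, padding=1),
        nn.PixelShuffle(STRIDE),
    )


def time_forward_backward(module: nn.Module, device: str, iters: int = 10):
    """Return (seconds per forward+backward pass, output shape)."""
    module = module.to(device)
    x = torch.randn(BATCH, IN_FM, H, H, device=device, requires_grad=True)
    module(x).sum().backward()  # warm-up, not timed
    x.grad.sum().item()  # wait for the warm-up to finish before starting the clock
    start = time.perf_counter()
    for _ in range(iters):
        out = module(x)
        out.sum().backward()
    # .item() copies a value to the CPU, which waits for queued device work.
    x.grad.sum().item()
    return (time.perf_counter() - start) / iters, tuple(out.shape)


device = "mps" if torch.backends.mps.is_available() else "cpu"
for name, factory in [("ConvTranspose2d", make_tconv), ("Conv2d+PixelShuffle", make_shuffle)]:
    secs, shape = time_forward_backward(factory(), device)
    print(f"{device} {name}: {secs * 1e3:.2f} ms/iter, output {shape}")
```

Because MPS work is queued asynchronously, the clock is only read after `.item()` forces a sync; without that, the loop would mostly measure how fast work is enqueued rather than how fast it runs.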

The commented-out line shows the original ConvTranspose2d, which was replaced by a Conv2d followed by PixelShuffle:

```python
import torch.nn as nn

class decodeUpBlock(nn.Module):
    def __init__(self, inFM, outFM, kSize, outSize, stride=1, padding=0, output_padding=0):
        super(decodeUpBlock, self).__init__()
        self.layers = nn.Sequential()
        #self.layers.add_module("TConv", nn.ConvTranspose2d(inFM, outFM, kSize, stride=stride, padding=padding, output_padding=output_padding))
        # PixelShuffle(r) maps C*r*r channels to C channels at r-times the resolution,
        # so the conv must emit outFM * stride**2 channels (the original 4*outFM assumes stride=2).
        self.layers.add_module("Conv2", nn.Conv2d(inFM, outFM * stride ** 2, kSize, stride=1, padding=padding))
        self.layers.add_module("pShfl", nn.PixelShuffle(stride))
        self.layers.add_module("lNorm", nn.LayerNorm([outFM, outSize, outSize]))
        self.layers.add_module("ReLU", nn.LeakyReLU(0.2, inplace=True))

    def forward(self, x):
        return self.layers(x)
```
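PixelShuffle itself is pure data movement, so the same rearrangement can be written with `view`/`permute`/`reshape`. The sketch below does that to show what the layer computes; `pixel_shuffle_reshape` is a hypothetical helper, not a PyTorch API. Swapping it in for `nn.PixelShuffle` (wrapped in a small module) is one possible experiment to see whether the native MPS kernel is the bottleneck; whether it is actually faster on MPS is untested here.

```python
import torch
import torch.nn.functional as F


def pixel_shuffle_reshape(x: torch.Tensor, r: int) -> torch.Tensor:
    """Hypothetical helper: (N, C*r*r, H, W) -> (N, C, H*r, W*r) via view/permute."""
    n, c_r2, h, w = x.shape
    assert c_r2 % (r * r) == 0, "channel count must be divisible by r*r"
    c = c_r2 // (r * r)
    # Split channels into (C, r, r): the two r axes are the row and column
    # offsets of each value inside its r x r output patch.
    x = x.view(n, c, r, r, h, w)
    # Interleave those offsets with the spatial axes: (N, C, H, r, W, r).
    x = x.permute(0, 1, 4, 2, 5, 3)
    # Merging (H, r) and (W, r) yields the upscaled image; reshape copies
    # because the permuted tensor is not contiguous.
    return x.reshape(n, c, h * r, w * r)


x = torch.randn(2, 3 * 2 * 2, 5, 7)
print(torch.equal(pixel_shuffle_reshape(x, 2), F.pixel_shuffle(x, 2)))  # True
```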

Versions

Collecting environment information...
PyTorch version: 1.12.0.post2
Is debug build: False
CUDA used to build PyTorch: None
ROCM used to build PyTorch: N/A

OS: macOS 13.0 (arm64)
GCC version: Could not collect
Clang version: 13.1.6 (clang-1316.0.21.2.5)
CMake version: Could not collect
Libc version: N/A

Python version: 3.10.5 | packaged by conda-forge | (main, Jun 14 2022, 07:07:06) [Clang 13.0.1 ] (64-bit runtime)
Python platform: macOS-13.0-arm64-arm-64bit
Is CUDA available: False
CUDA runtime version: No CUDA
GPU models and configuration: No CUDA
Nvidia driver version: No CUDA
cuDNN version: No CUDA
HIP runtime version: N/A
MIOpen runtime version: N/A
Is XNNPACK available: True

Versions of relevant libraries:
[pip3] numpy==1.23.1
[pip3] torch==1.12.0.post2
[pip3] torchvision==0.13.0a0
[conda] numpy 1.23.1 py310h0a343b5_0 conda-forge
[conda] pytorch 1.12.0 cpu_py310h911b1ea_2 conda-forge
[conda] torchvision 0.13.0 cpu_py310he68663e_0 conda-forge

cc @VitalyFedyunin @ngimel @kulinseth @albanD

Labels: module: mps, module: performance, triaged
