Description
🐛 Describe the bug
I changed the decoder layers of a CNN to use PixelShuffle instead of ConvTranspose2d. The model immediately took 3x longer to run. GPU usage was still high, and there was no warning about falling back to the CPU. I recently started training the same CNN remotely using CUDA and saw training time drop by 66%. The only explanation I can see is a problem with the MPS implementation.
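PixelShuffle does no arithmetic. It only rearranges a (N, C·r², H, W) tensor into (N, C, H·r, W·r). The sketch below reproduces that rearrangement with reshape/permute. The helper name `pixel_shuffle_manual` is hypothetical and is not a PyTorch API. Swapping it in for `nn.PixelShuffle` on MPS could help check whether the slowdown comes from the PixelShuffle kernel itself or from somewhere else:

```python
import torch
import torch.nn as nn


def pixel_shuffle_manual(x: torch.Tensor, r: int) -> torch.Tensor:
    """Hypothetical reshape/permute equivalent of nn.PixelShuffle(r)."""
    n, c, h, w = x.shape
    assert c % (r * r) == 0, "channel count must be divisible by r**2"
    oc = c // (r * r)
    # Split channels into (oc, r, r): input channel = oc_idx*r*r + i*r + j
    x = x.reshape(n, oc, r, r, h, w)
    # Interleave the two r factors with the spatial dims: (n, oc, h, i, w, j)
    x = x.permute(0, 1, 4, 2, 5, 3)
    # Merge (h, i) -> row h*r + i and (w, j) -> column w*r + j
    return x.reshape(n, oc, h * r, w * r)


x = torch.randn(2, 16, 5, 7)
out = pixel_shuffle_manual(x, 2)
print(out.shape)                                 # torch.Size([2, 4, 10, 14])
print(torch.equal(out, nn.PixelShuffle(2)(x)))  # True
```

Because both versions only move data, the results match exactly. Any timing gap between them on MPS would come from how each is executed, not from what it computes.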
The commented-out line shows the ConvTranspose2d that was replaced by a Conv2d followed by PixelShuffle:
```python
import torch.nn as nn


class decodeUpBlock(nn.Module):
    def __init__(self, inFM, outFM, kSize, outSize, stride=1, padding=0, output_padding=0):
        super(decodeUpBlock, self).__init__()
        self.layers = nn.Sequential()
        #self.layers.add_module("TConv", nn.ConvTranspose2d(inFM,outFM,kSize,stride=stride,padding=padding,output_padding=output_padding))
        # PixelShuffle(stride) needs outFM * stride**2 input channels (4*outFM when stride=2)
        self.layers.add_module("Conv2", nn.Conv2d(inFM, outFM * stride**2, kSize, stride=1, padding=padding))
        self.layers.add_module("pShfl", nn.PixelShuffle(stride))  # (C*r^2, H, W) -> (C, H*r, W*r)
        self.layers.add_module("lNorm", nn.LayerNorm([outFM, outSize, outSize]))
        self.layers.add_module("ReLU", nn.LeakyReLU(0.2, inplace=True))

    def forward(self, x):
        return self.layers(x)
```
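A timing sketch like the one below could make the 3x figure reproducible by comparing the two decoder variants on CPU and, when available, MPS. The layer sizes, batch size, and iteration count are arbitrary assumptions, not the reporter's model. `make_block` and `time_block` are hypothetical helpers. MPS work is queued asynchronously, so each step copies a scalar back to the host with `.item()`. This makes the host wait for the queued work before the clock is read:

```python
import time

import torch
import torch.nn as nn


def make_block(use_pixel_shuffle: bool, in_fm=64, out_fm=32, k=3, stride=2,
               padding=1, output_padding=1, out_size=32) -> nn.Module:
    """Hypothetical helper: one decoder block in either variant from the report.

    With a 16x16 input, both variants produce a 32x32 output for these defaults.
    """
    if use_pixel_shuffle:
        up = [nn.Conv2d(in_fm, out_fm * stride ** 2, k, stride=1, padding=padding),
              nn.PixelShuffle(stride)]
    else:
        up = [nn.ConvTranspose2d(in_fm, out_fm, k, stride=stride, padding=padding,
                                 output_padding=output_padding)]
    return nn.Sequential(*up, nn.LayerNorm([out_fm, out_size, out_size]),
                         nn.LeakyReLU(0.2, inplace=True))


def time_block(block: nn.Module, device: torch.device, batch=8, in_fm=64,
               in_size=16, iters=20) -> float:
    """Average seconds per forward+backward pass of `block` on `device`."""
    block = block.to(device)
    x = torch.randn(batch, in_fm, in_size, in_size, device=device)

    def step() -> float:
        block.zero_grad()
        loss = block(x).square().mean()
        loss.backward()
        return loss.item()  # host copy waits for the queued device work

    step()  # warm-up pass, excluded from timing
    start = time.perf_counter()
    for _ in range(iters):
        step()
    return (time.perf_counter() - start) / iters


devices = [torch.device("cpu")]
if torch.backends.mps.is_available():
    devices.append(torch.device("mps"))

for dev in devices:
    for ps in (False, True):
        torch.manual_seed(0)
        t = time_block(make_block(ps), dev)
        name = "Conv2d+PixelShuffle" if ps else "ConvTranspose2d"
        print(f"{dev.type:>4}  {name:<20} {t * 1e3:.2f} ms/iter")
```

On a machine without MPS, only the CPU rows are printed. A PixelShuffle row that is much slower only on `mps` would be consistent with this report.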
Versions
Collecting environment information...
PyTorch version: 1.12.0.post2
Is debug build: False
CUDA used to build PyTorch: None
ROCM used to build PyTorch: N/A
OS: macOS 13.0 (arm64)
GCC version: Could not collect
Clang version: 13.1.6 (clang-1316.0.21.2.5)
CMake version: Could not collect
Libc version: N/A
Python version: 3.10.5 | packaged by conda-forge | (main, Jun 14 2022, 07:07:06) [Clang 13.0.1 ] (64-bit runtime)
Python platform: macOS-13.0-arm64-arm-64bit
Is CUDA available: False
CUDA runtime version: No CUDA
GPU models and configuration: No CUDA
Nvidia driver version: No CUDA
cuDNN version: No CUDA
HIP runtime version: N/A
MIOpen runtime version: N/A
Is XNNPACK available: True
Versions of relevant libraries:
[pip3] numpy==1.23.1
[pip3] torch==1.12.0.post2
[pip3] torchvision==0.13.0a0
[conda] numpy 1.23.1 py310h0a343b5_0 conda-forge
[conda] pytorch 1.12.0 cpu_py310h911b1ea_2 conda-forge
[conda] torchvision 0.13.0 cpu_py310he68663e_0 conda-forge
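The environment output above does not mention MPS. A short check like the following could confirm that the MPS backend was built and is usable on this machine. It uses the `torch.backends.mps.is_built()` and `is_available()` calls, which, to my knowledge, exist in PyTorch 1.12 and later:

```python
import torch

mps_built = torch.backends.mps.is_built()
mps_available = torch.backends.mps.is_available()

print("PyTorch:", torch.__version__)
print("MPS built:", mps_built)
print("MPS available:", mps_available)

if mps_available:
    # A tiny op on the MPS device confirms tensors actually land there
    t = torch.ones(3, device="mps") * 2
    print("MPS tensor device:", t.device)
```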