Support `torch.concat` alias, add `cat` OpInfo & remove OpInfo test_out skips {cat, stack, hstack, vstack, dstack} by AnirudhDagar · Pull Request #62560 · pytorch/pytorch · GitHub

Conversation


@AnirudhDagar AnirudhDagar commented Aug 2, 2021

Fixes #61767

Changes

  • Add torch.concat alias to torch.cat
  • Add OpInfo for cat/concat
  • Fix test_out skips (Use at::native::resize_output or at::native::resize_output_check)
    • cat/concat
    • stack
    • hstack
    • dstack
    • vstack/row_stack
  • Remove redundant tests for cat/stack

I've not added cat/concat to the OpInfo op_db yet, since cat is a little trickier than other OpInfos (it should have a lot of tests) and there is currently no existing OpInfo of this kind to model it on. I can try to add that in a subsequent PR or maybe here itself, whatever is suggested.
Edit: cat/concat OpInfo has been added.

Note: I've added named tensor support for the concat alias as well; that may be out of spec for the array API, but it is still useful for consistency within PyTorch.
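
For reference, a minimal usage sketch of the alias (it takes exactly the same arguments as torch.cat, including dim= and out=):

    import torch

    a = torch.randn(2, 3)
    b = torch.randn(2, 3)
    # torch.concat simply dispatches to torch.cat, so the results match exactly
    assert torch.equal(torch.concat((a, b), dim=0), torch.cat((a, b), dim=0))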

Thanks to @krshrimali for guidance on my first PR :))

cc @mruberry @rgommers @pmeier @asmeurer @leofang @AnirudhDagar @asi1024 @emcastillo @kmaehashi @heitorschueroff @krshrimali

@AnirudhDagar AnirudhDagar requested a review from ezyang as a code owner August 2, 2021 07:16
@facebook-github-bot facebook-github-bot added the oncall: jit (Add this issue/PR to JIT oncall triage queue) and cla signed labels Aug 2, 2021

facebook-github-bot commented Aug 2, 2021


💊 CI failures summary and remediations

As of commit 95f9736 (more details on the Dr. CI page):


  • 3/3 failures introduced in this PR

🕵️ 3 new failures recognized by patterns

The following CI failures do not appear to be due to upstream breakages:

See CircleCI build pytorch_linux_xenial_cuda11_1_cudnn8_py3_gcc7_build (1/3)

Step: "Build" (full log | diagnosis details | 🔁 rerun)

Aug 26 06:53:03 ERROR 2021-08-26T02:16:23Z: sccache::server: Compilation failed: Output { status: ExitStatus(ExitStatus(256)), stdout: "", stderr: "conftest.c: In function \'main\':\nconftest.c:332:2: error: \'struct sockaddr\' has no member named \'sa_len\'\n x.sa_len = 0;\n  ^\n" }
Aug 26 06:53:03 
Aug 26 06:53:03 ERROR 2021-08-26T02:16:26Z: sccache::server: Compilation failed: Output { status: ExitStatus(ExitStatus(256)), stdout: "", stderr: "conftest.c: In function \'main\':\nconftest.c:366:10: error: \'RTLD_MEMBER\' undeclared (first use in this function); did you mean \'RTLD_NEXT\'?\n   (void) RTLD_MEMBER;\n          ^~~~~~~~~~~\n          RTLD_NEXT\nconftest.c:366:10: note: each undeclared identifier is reported only once for each function it appears in\n" }
Aug 26 06:53:03 
Aug 26 06:53:03 ERROR 2021-08-26T02:16:27Z: sccache::server: Compilation failed: Output { status: ExitStatus(ExitStatus(256)), stdout: "", stderr: "conftest.c:361:9: error: unknown type name \'not\'\n         not a universal capable compiler\n         ^~~\nconftest.c:361:15: error: expected \'=\', \',\', \';\', \'asm\' or \'__attribute__\' before \'universal\'\n         not a universal capable compiler\n               ^~~~~~~~~\nconftest.c:361:15: error: unknown type name \'universal\'\n" }
Aug 26 06:53:03 
Aug 26 06:53:03 ERROR 2021-08-26T02:16:27Z: sccache::server: Compilation failed: Output { status: ExitStatus(ExitStatus(256)), stdout: "", stderr: "conftest.c: In function \'main\':\nconftest.c:367:4: error: unknown type name \'not\'; did you mean \'ino_t\'?\n    not big endian\n    ^~~\n    ino_t\nconftest.c:367:12: error: expected \'=\', \',\', \';\', \'asm\' or \'__attribute__\' before \'endian\'\n    not big endian\n            ^~~~~~\n" }
Aug 26 06:53:03 
Aug 26 06:53:03 ERROR 2021-08-26T02:16:28Z: sccache::server: Compilation failed: Output { status: ExitStatus(ExitStatus(256)), stdout: "", stderr: "conftest.c: In function \'main\':\nconftest.c:378:4: error: \'struct stat\' has no member named \'st_mtimespec\'; did you mean \'st_mtim\'?\n st.st_mtimespec.tv_nsec = 1;\n    ^~~~~~~~~~~~\n    st_mtim\n" }
Aug 26 06:53:03 
Aug 26 06:53:03 ERROR 2021-08-26T02:16:30Z: sccache::server: Compilation failed: Output { status: ExitStatus(ExitStatus(256)), stdout: "", stderr: "conftest.c: In function \'main\':\nconftest.c:402:24: error: expected expression before \')\' token\n if (sizeof ((socklen_t)))\n                        ^\n" }
Aug 26 06:53:03 
Aug 26 06:53:03 =========== If your build fails, please take a look at the log above for possible reasons ===========
Aug 26 06:53:03 Compile requests                   14032
Aug 26 06:53:03 Compile requests executed           8085
Aug 26 06:53:03 Cache hits                          6532
Aug 26 06:53:03 Cache hits (C/C++)                  6218
Aug 26 06:53:03 Cache hits (CUDA)                    314
Aug 26 06:53:03 Cache misses                        1479
Aug 26 06:53:03 Cache misses (C/C++)                1146
Aug 26 06:53:03 Cache misses (CUDA)                  333

See CircleCI build pytorch_linux_bionic_cuda10_2_cudnn7_py3_9_gcc7_build (2/3)

Step: "Build" (full log | diagnosis details | 🔁 rerun)

Aug 26 02:07:21 ++++ extract_trap_cmd
Aug 26 02:07:21 ++++ printf '%s\n' ''
Aug 26 02:07:21 +++ printf '%s\n' cleanup
Aug 26 02:07:21 ++ trap -- '
Aug 26 02:07:21 cleanup' EXIT
Aug 26 02:07:21 ++ [[ pytorch-linux-bionic-cuda10.2-cudnn7-py3.9-gcc7-build != *win-* ]]
Aug 26 02:07:21 ++ which sccache
Aug 26 02:07:21 ++ sccache --stop-server
Aug 26 02:07:21 ++ true
Aug 26 02:07:21 ++ rm /var/lib/jenkins/sccache_error.log
Aug 26 02:07:21 rm: cannot remove '/var/lib/jenkins/sccache_error.log': No such file or directory
Aug 26 02:07:21 ++ true
Aug 26 02:07:21 ++ [[ -n '' ]]
Aug 26 02:07:21 ++ [[ pytorch-linux-bionic-cuda10.2-cudnn7-py3.9-gcc7-build == *rocm* ]]
Aug 26 02:07:21 ++ SCCACHE_ERROR_LOG=/var/lib/jenkins/sccache_error.log
Aug 26 02:07:21 ++ SCCACHE_IDLE_TIMEOUT=1200
Aug 26 02:07:21 ++ RUST_LOG=sccache::server=error
Aug 26 02:07:21 ++ sccache --start-server
Aug 26 02:07:21 sccache: Starting the server...
Aug 26 02:07:21 ++ sccache --zero-stats
Aug 26 02:07:21 Compile requests                      0

See CircleCI build pytorch_windows_vs2019_py38_cuda10.1_build (3/3)

Step: "Build" (full log | diagnosis details | 🔁 rerun)

Error generating file
reduce.cu
C:/Program Files/NVIDIA GPU Computing Toolkit/CUDA/v10.1/include\thrust/system/cuda/detail/core/util.h(610): error C2065: 'S': undeclared identifier
C:/Program Files/NVIDIA GPU Computing Toolkit/CUDA/v10.1/include\thrust/system/cuda/detail/core/util.h(611): note: see reference to class template instantiation 'thrust::cuda_cub::core::LoadIterator<PtxPlan,It>' being compiled
Retry attempt: 3
reduce.cu
C:/Program Files/NVIDIA GPU Computing Toolkit/CUDA/v10.1/include\thrust/system/cuda/detail/core/util.h(610): error C2065: 'S': undeclared identifier
C:/Program Files/NVIDIA GPU Computing Toolkit/CUDA/v10.1/include\thrust/system/cuda/detail/core/util.h(611): note: see reference to class template instantiation 'thrust::cuda_cub::core::LoadIterator<PtxPlan,It>' being compiled
-- Removing C:/Users/circleci/project/build/caffe2/CMakeFiles/torch_cuda.dir/utils/math/./torch_cuda_generated_reduce.cu.obj
C:/Jenkins/Miniconda3/Library/bin/cmake.exe -E remove C:/Users/circleci/project/build/caffe2/CMakeFiles/torch_cuda.dir/utils/math/./torch_cuda_generated_reduce.cu.obj
CMake Error at torch_cuda_generated_reduce.cu.obj.Release.cmake:281 (message):
  Error generating file
  C:/Users/circleci/project/build/caffe2/CMakeFiles/torch_cuda.dir/utils/math/./torch_cuda_generated_reduce.cu.obj


[5409/6272] cmd.exe /C "cd /D C:\Users\circleci\project\build\caffe2\CMakeFiles\torch_cuda.dir\operators && C:\Jenkins\Miniconda3\Library\bin\cmake.exe -E make_directory C:/Users/circleci/project/build/caffe2/CMakeFiles/torch_cuda.dir/operators/. && C:\Jenkins\Miniconda3\Library\bin\cmake.exe -D verbose:BOOL=ON -D build_configuration:STRING=Release -D generated_file:STRING=C:/Users/circleci/project/build/caffe2/CMakeFiles/torch_cuda.dir/operators/./torch_cuda_generated_elementwise_div_op.cu.obj -D generated_cubin_file:STRING=C:/Users/circleci/project/build/caffe2/CMakeFiles/torch_cuda.dir/operators/./torch_cuda_generated_elementwise_div_op.cu.obj.cubin.txt -P C:/Users/circleci/project/build/caffe2/CMakeFiles/torch_cuda.dir/operators/torch_cuda_generated_elementwise_div_op.cu.obj.Release.cmake"
-- Removing C:/Users/circleci/project/build/caffe2/CMakeFiles/torch_cuda.dir/operators/./torch_cuda_generated_elementwise_div_op.cu.obj
C:/Jenkins/Miniconda3/Library/bin/cmake.exe -E remove C:/Users/circleci/project/build/caffe2/CMakeFiles/torch_cuda.dir/operators/./torch_cuda_generated_elementwise_div_op.cu.obj
-- Generating dependency file: C:/Users/circleci/project/build/caffe2/CMakeFiles/torch_cuda.dir/operators/torch_cuda_generated_elementwise_div_op.cu.obj.NVCC-depend
C:/Users/circleci/project/build/win_tmp/bin/randomtemp.exe -M -D__CUDACC__ C:/Users/circleci/project/caffe2/operators/elementwise_div_op.cu -o C:/Users/circleci/project/build/caffe2/CMakeFiles/torch_cuda.dir/operators/torch_cuda_generated_elementwise_div_op.cu.obj.NVCC-depend -ccbin cl.exe -m64 -Dtorch_cuda_EXPORTS -DUSE_CUDA -DTORCH_CUDA_BUILD_MAIN_LIB -DWIN32_LEAN_AND_MEAN -DTH_BLAS_MKL -D_OPENMP_NOFORCE_MANIFEST -DONNX_ML=1 -DONNXIFI_ENABLE_EXT=1 -DONNX_NAMESPACE=onnx_torch -D_CRT_SECURE_NO_DEPRECATE=1 -DMAGMA_V2 -DIDEEP_USE_MKL -DUSE_EXTERNAL_MZCRC -DMINIZ_DISABLE_ZIP_READER_CRC32_CHECKS -DUSE_DISTRIBUTED -DUSE_C10D_GLOO -Xcompiler ,\"/DWIN32\",\"/D_WINDOWS\",\"/GR\",\"/EHsc\",\"/w\",\"/bigobj\",\"-DUSE_PTHREADPOOL\",\"-openmp:experimental\",\"-IC:/Users/circleci/project/build/win_tmp/mkl/include\",\"-DNDEBUG\",\"-DUSE_KINETO\",\"-DLIBKINETO_NOCUPTI\",\"-DUSE_FBGEMM\",\"-DUSE_XNNPACK\",\"-DSYMBOLICATE_MOBILE_DEBUG_HANDLE\",\"-DEDGE_PROFILER_USE_KINETO\",\"-DHAVE_AVX512_CPU_DEFINITION\",\"-DHAVE_AVX2_CPU_DEFINITION\",\"/MD\",\"/O2\",\"/Ob2\",\"/DNDEBUG\",\"/w\",\"/bigobj\",\"-DNDEBUG\" -Xcompiler /w -w -Xfatbin -compress-all -DONNX_NAMESPACE=onnx_torch --use-local-env -gencode arch=compute_52,code=sm_52 -gencode arch=compute_75,code=sm_75 -Xcudafe --diag_suppress=cc_clobber_ignored,--diag_suppress=integer_sign_change,--diag_suppress=useless_using_declaration,--diag_suppress=set_but_not_used,--diag_suppress=field_without_dll_interface,--diag_suppress=base_class_has_different_dll_interface,--diag_suppress=dll_interface_conflict_none_assumed,--diag_suppress=dll_interface_conflict_dllexport_assumed,--diag_suppress=implicit_return_from_non_void_function,--diag_suppress=unsigned_compare_with_zero,--diag_suppress=declared_but_not_referenced,--diag_suppress=bad_friend_decl --Werror cross-execution-space-call --no-host-device-move-forward -Xcompiler -MD --expt-relaxed-constexpr --expt-extended-lambda -Xcompiler=/wd4819,/wd4503,/wd4190,/wd4244,/wd4251,/wd4275,/wd4522 -Wno-deprecated-gpu-targets --expt-extended-lambda -DCUDA_HAS_FP16=1 -D__CUDA_NO_HALF_OPERATORS__ -D__CUDA_NO_HALF_CONVERSIONS__ -D__CUDA_NO_BFLOAT16_CONVERSIONS__ -D__CUDA_NO_HALF2_OPERATORS__ -DNVCC "-IC:/Program Files/NVIDIA GPU Computing Toolkit/CUDA/v10.1/include" -IC:/Users/circleci/project/build/aten/src -IC:/Users/circleci/project/aten/src -IC:/Users/circleci/project/build -IC:/Users/circleci/project -IC:/Users/circleci/project/build/third_party/gloo -IC:/Users/circleci/project/cmake/../third_party/gloo -IC:/Users/circleci/project/cmake/../third_party/googletest/googlemock/include -IC:/Users/circleci/project/cmake/../third_party/googletest/googletest/include -IC:/Users/circleci/project/third_party/protobuf/src -IC:/Users/circleci/project/build/win_tmp/mkl/include -IC:/Users/circleci/project/third_party/XNNPACK/include -IC:/Users/circleci/project/cmake/../third_party/benchmark/include -IC:/Users/circleci/project/third_party -IC:/Users/circleci/project/cmake/../third_party/eigen -IC:/Jenkins/Miniconda3/include -IC:/Jenkins/Miniconda3/lib/site-packages/numpy/core/include -IC:/Users/circleci/project/cmake/../third_party/pybind11/include -IC:/Users/circleci/project/cmake/../third_party/cudnn_frontend/include -IC:/Users/circleci/project/cmake/../third_party/cub -IC:/Users/circleci/project/build/caffe2/contrib/aten -IC:/Users/circleci/project/third_party/onnx -IC:/Users/circleci/project/build/third_party/onnx -IC:/Users/circleci/project/third_party/foxi -IC:/Users/circleci/project/build/third_party/foxi 
-IC:/Users/circleci/project/build/win_tmp/magma/include -IC:/Users/circleci/project/third_party/ideep/mkl-dnn/include -IC:/Users/circleci/project/third_party/ideep/include -IC:/Users/circleci/project/build/include -IC:/Users/circleci/project/torch/csrc/distributed -IC:/Users/circleci/project/build/caffe2/aten/src/TH -IC:/Users/circleci/project/aten/src/TH -IC:/Users/circleci/project/build/caffe2/aten/src/THC -IC:/Users/circleci/project/aten/src/THC -IC:/Users/circleci/project/aten/src/ATen/cuda -IC:/Users/circleci/project/build/caffe2/aten/src -IC:/Users/circleci/project/aten/../third_party/catch/single_include -IC:/Users/circleci/project/aten/src/ATen/.. -IC:/Users/circleci/project/build/caffe2/aten/src/ATen -IC:/Users/circleci/project/c10/cuda/../.. -IC:/Users/circleci/project/c10/../ "-IC:/Program Files/NVIDIA Corporation/NvToolsExt/include" -IC:/Users/circleci/project/torch/csrc/api -IC:/Users/circleci/project/torch/csrc/api/include -IC:/Users/circleci/project/build/third_party/ideep/mkl-dnn/include -IC:/Users/circleci/project/third_party/ideep/mkl-dnn/src/../include
elementwise_div_op.cu
-- Generating temporary cmake readable file: C:/Users/circleci/project/build/caffe2/CMakeFiles/torch_cuda.dir/operators/torch_cuda_generated_elementwise_div_op.cu.obj.depend.tmp

2 jobs timed out:

  • pytorch_linux_xenial_cuda11_1_cudnn8_py3_gcc7_build
  • pytorch_linux_bionic_cuda10_2_cudnn7_py3_9_gcc7_build

This comment was automatically generated by Dr. CI.

@ezyang ezyang requested review from mruberry and removed request for ezyang August 2, 2021 13:01
@ezyang ezyang added the triaged label (This issue has been looked at by a team member, and triaged and prioritized into an appropriate module) Aug 2, 2021

mruberry commented Aug 3, 2021

This is cool @AnirudhDagar and the implementation looks correct.

This does need to add an OpInfo for torch.cat and then list torch.concat as an alias of it, however, because otherwise the alias isn't tested. I would just extend this PR with the OpInfo. @krshrimali and @kshitij12345 should be able to help you. I may be unavailable for the rest of the week, but if you need help merging this PR you can ping @heitorschueroff and @anjali411, and one of them should be available to help you.
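
For context, a hedged, schematic sketch of how an op_db entry can declare an alias so that the alias gets tested automatically (the sample-inputs helper below is a simplified stand-in, not necessarily what this PR ends up adding):

    import numpy as np
    import torch
    from torch.testing._internal.common_methods_invocations import OpInfo, SampleInput

    def sample_inputs_cat_sketch(op_info, device, dtype, requires_grad, **kwargs):
        # A couple of tensors that differ only along the concatenation dimension.
        ts = [torch.randn(2, 3, device=device, dtype=dtype, requires_grad=requires_grad)
              for _ in range(2)]
        return [SampleInput(ts, kwargs={'dim': 0})]

    cat_opinfo_sketch = OpInfo(
        'cat',
        aliases=('concat',),
        ref=lambda input_seq, dim=0: np.concatenate(input_seq, axis=dim),
        sample_inputs_func=sample_inputs_cat_sketch,
    )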


AnirudhDagar commented Aug 3, 2021

Thanks for the review @mruberry! No worries, I'll add the OpInfo and update the PR. I'm planning to contribute to other missing opinfos later, so this should be a good start.


mruberry commented Aug 3, 2021

Thanks for the review @mruberry! No worries, I'll add the OpInfo and update the PR. I'm planning to contribute to other missing opinfos later, so this should be a good start.

Sounds good! It'll be great to have more OpInfos!

@anjali411 anjali411 self-requested a review August 3, 2021 14:42

AnirudhDagar commented Aug 7, 2021

I've added OpInfo for cat/concat. There are a few skips that were required though, namely:

cc @krshrimali @anjali411


@krshrimali krshrimali left a comment


Thanks, @AnirudhDagar for working on this. Looks like a great start, left a few suggestions for sample inputs.


codecov bot commented Aug 9, 2021

Codecov Report

Merging #62560 (c7db642) into master (aa5e3ad) will increase coverage by 12.50%.
The diff coverage is 68.99%.

❗ Current head c7db642 differs from pull request most recent head c624631. Consider uploading reports for the commit c624631 to get more accurate results

@@             Coverage Diff             @@
##           master   #62560       +/-   ##
===========================================
+ Coverage   47.37%   59.88%   +12.50%     
===========================================
  Files         660      683       +23     
  Lines       86440    88279     +1839     
===========================================
+ Hits        40951    52864    +11913     
+ Misses      45489    35415    -10074     

@AnirudhDagar

@krshrimali thanks for reviewing and sharing useful suggestions.
Scalar tensors/arrays (zero-dimensional) are not supported by torch.{cat,concat}/np.concatenate, so such a test wouldn't pass anyway; we should avoid it.
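
For illustration (a hedged sketch; the exact error messages may differ slightly between versions):

    import numpy as np
    import torch

    scalar = torch.tensor(1.0)  # zero-dimensional tensor
    try:
        torch.cat((scalar, scalar))
    except RuntimeError as e:
        print("torch.cat rejects 0-d inputs:", e)

    try:
        np.concatenate((np.array(1.0), np.array(1.0)))
    except ValueError as e:
        print("np.concatenate rejects 0-d inputs:", e)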

It will be good to test when dim is not passed, so something like a dictionary in the cases is useful

That's a good idea, I've updated numpy reference implementation to accommodate axis instead of dim kwarg.
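
A hedged sketch of the kind of dim-to-axis translation involved (the actual reference wrapper in the PR may be written differently):

    import numpy as np
    import torch

    def np_cat_reference(input_seq, dim=0):
        # torch.cat takes `dim`; np.concatenate takes `axis`
        return np.concatenate([t.detach().cpu().numpy() for t in input_seq], axis=dim)

    a, b = torch.randn(2, 3), torch.randn(4, 3)
    assert np.allclose(np_cat_reference((a, b), dim=0), torch.cat((a, b), dim=0).numpy())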

I was also able to remove the skip for test_variant_consistency_jit with assert_autodiffed=True. I've added a comment for test_jit_alias_remapping, which is the only skipped test now.

@anjali411 @kshitij12345
I've tweaked resize_output to handle memory_format (which was missing earlier), and I'm now reusing that function in the cat_out CPU/CUDA kernels.
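
To illustrate the user-visible effect of routing the out= resize through resize_output (a hedged sketch; the exact warning text is omitted):

    import torch

    a = torch.randn(2, 3)
    b = torch.randn(2, 3)

    out = torch.empty(0)        # 0-element out= is resized silently
    torch.cat((a, b), out=out)
    print(out.shape)            # torch.Size([4, 3])

    wrong = torch.empty(5, 5)   # wrong non-empty shape is resized, with a UserWarning
    torch.cat((a, b), out=wrong)
    print(wrong.shape)          # torch.Size([4, 3])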

@AnirudhDagar AnirudhDagar changed the title Support torch.concat alias to torch.cat Support torch.concat alias to torch.cat & add cat OpInfo Aug 10, 2021

@kshitij12345 kshitij12345 left a comment


I think you can remove this test as it should be covered by test_out in test_ops.py

@onlyCUDA
def test_cat_stack_cross_devices(self, device):
    cuda = torch.randn((3, 3), device=device)
    cpu = torch.randn((3, 3), device='cpu')
    out_cpu = cpu.clone()
    out_cuda = cuda.clone()
    with self.assertRaisesRegex(RuntimeError,
                                "Expected all tensors to be on the same device"):
        torch.cat((cuda, cpu))
    with self.assertRaisesRegex(RuntimeError,
                                "Expected all tensors to be on the same device"):
        torch.cat((cpu, cuda))
    with self.assertRaisesRegex(RuntimeError,
                                "Expected all tensors to be on the same device"):
        torch.cat((cpu, cuda), out=out_cuda)
    with self.assertRaisesRegex(RuntimeError,
                                "Expected all tensors to be on the same device"):
        torch.cat((cpu, cpu), out=out_cuda)
    with self.assertRaisesRegex(RuntimeError,
                                "Expected all tensors to be on the same device"):
        torch.cat((cuda, cuda), out=out_cpu)
    # Stack
    with self.assertRaisesRegex(RuntimeError,
                                "Expected all tensors to be on the same device"):
        torch.stack((cuda, cpu))
    with self.assertRaisesRegex(RuntimeError,
                                "Expected all tensors to be on the same device"):
        torch.stack((cpu, cuda))
    with self.assertRaisesRegex(RuntimeError,
                                "Expected all tensors to be on the same device"):
        torch.stack((cpu, cuda), out=out_cuda)
    with self.assertRaisesRegex(RuntimeError,
                                "Expected all tensors to be on the same device"):
        torch.stack((cpu, cpu), out=out_cuda)
    with self.assertRaisesRegex(RuntimeError,
                                "Expected all tensors to be on the same device"):
        torch.stack((cuda, cuda), out=out_cpu)

cc: @mruberry


@mruberry mruberry left a comment


This looks really good @AnirudhDagar, I have just a couple questions for your review. Looking forward to hearing your thoughts!


AnirudhDagar commented Aug 19, 2021

Problem with cat_out memory format behaviour

Replying to this review comment here because the explanation became quite long. Apologies for that.

@mruberry this turned out to be more interesting than I initially expected. The reason is that cat_out_cpu and cat_out_cuda follow slightly different behaviours when resizing an out= tensor.

  • cat_out_cpu: If out= is not the correct shape, then the memory format of the first input tensor is used.
  • cat_out_cuda: If out= is not the correct shape, then the memory format is computed using compute_output_memory_format, calling it here (the point of difference), and that format is then used for resizing. If you look at the implementation, it does not directly use the first tensor's memory format, unlike the CPU variant. To be precise, according to L446 in Shape.cu, bool contiguous = true is assigned even if the first tensor has a different format (e.g. channels_last) while some other tensor has contiguous_format.
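
A hedged Python paraphrase of the CUDA-side decision described above (the real logic lives in C++ in Shape.cu; the function name here is only illustrative):

    import torch

    def compute_output_memory_format_sketch(tensors):
        # Propagate a memory format only if every input suggests the same one;
        # otherwise fall back to contiguous_format.
        formats = {
            torch.channels_last if t.is_contiguous(memory_format=torch.channels_last)
            else torch.contiguous_format
            for t in tensors
        }
        return formats.pop() if len(formats) == 1 else torch.contiguous_format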

So I wrote a separate test (see below) specifically to check that cat_out_cuda fails or passes for different inputs because of its behaviour.

The first test has inputs with different memory formats, so it will fail (i.e. not behave as described in @mruberry's comment) for CUDA but pass for CPU.

Click to expand: FAILING test_cat_out_memory_format
  def test_cat_out_memory_format(self, device):
      inp_size = (4, 4, 4, 4)
      x_cuda  = torch.randn(inp_size, device=device).contiguous(memory_format=torch.channels_last)
      x_cpu  = torch.randn(inp_size, device='cpu').contiguous(memory_format=torch.channels_last)
      y_cuda = torch.randn(inp_size, device=device)
      y_cpu = torch.randn(inp_size, device='cpu')

      # Case 1: if out= is the correct shape then the memory format of out= is respected
      expected_size = (8, 4, 4, 4)
      out_cuda = torch.empty(expected_size, device=device).contiguous(memory_format=torch.contiguous_format)
      res1_cuda = torch.cat((x_cuda, y_cuda), out=out_cuda)
      
      out_cpu = torch.empty(expected_size, device='cpu').contiguous(memory_format=torch.contiguous_format)
      res1_cpu = torch.cat((x_cpu, y_cpu), out=out_cpu)

      self.assertTrue(res1_cuda.is_contiguous(memory_format=torch.contiguous_format))
      self.assertTrue(res1_cpu.is_contiguous(memory_format=torch.contiguous_format))

      # Case 2: if out= is not the correct shape then the memory format is that of the first tensor
      # output size specifically set wrong so that it is reshaped internally in cat 
      out_cuda = torch.empty((0), device=device).contiguous(memory_format=torch.contiguous_format)
      res2_cuda = torch.cat((x_cuda, y_cuda), out=out_cuda)

      out_cpu = torch.empty((0), device='cpu').contiguous(memory_format=torch.contiguous_format)
      res2_cpu = torch.cat((x_cpu, y_cpu), out=out_cpu)

      self.assertTrue(res2_cuda.is_contiguous(memory_format=torch.channels_last))
      self.assertTrue(res2_cpu.is_contiguous(memory_format=torch.channels_last))

The second test has inputs with the same memory_format (different from out's memory_format); it will pass for both CPU and CUDA.

Click to expand: PASSING test_cat_out_memory_format
  def test_cat_out_memory_format(self, device):
      inp_size = (4, 4, 4, 4)
      x_cuda  = torch.randn(inp_size, device=device).contiguous(memory_format=torch.channels_last)
      x_cpu  = torch.randn(inp_size, device='cpu').contiguous(memory_format=torch.channels_last)
      # Making sure all input tensors follow the same memory_format makes the whole test happy again
      y_cuda = torch.randn(inp_size, device=device).contiguous(memory_format=torch.channels_last)
      y_cpu = torch.randn(inp_size, device='cpu').contiguous(memory_format=torch.channels_last)

      # Case 1: if out= is the correct shape then the memory format of out= is respected
      expected_size = (8, 4, 4, 4)
      out_cuda = torch.empty(expected_size, device=device).contiguous(memory_format=torch.contiguous_format)
      res1_cuda = torch.cat((x_cuda, y_cuda), out=out_cuda)
      
      out_cpu = torch.empty(expected_size, device='cpu').contiguous(memory_format=torch.contiguous_format)
      res1_cpu = torch.cat((x_cpu, y_cpu), out=out_cpu)

      self.assertTrue(res1_cuda.is_contiguous(memory_format=torch.contiguous_format))
      self.assertTrue(res1_cpu.is_contiguous(memory_format=torch.contiguous_format))

      # Case 2: if out= is not the correct shape then the memory format is that of the first tensor
      # output size specifically set wrong so that it is reshaped internally in cat 
      out_cuda = torch.empty((0), device=device).contiguous(memory_format=torch.contiguous_format)
      res2_cuda = torch.cat((x_cuda, y_cuda), out=out_cuda)

      out_cpu = torch.empty((0), device='cpu').contiguous(memory_format=torch.contiguous_format)
      res2_cpu = torch.cat((x_cpu, y_cpu), out=out_cpu)

      self.assertTrue(res2_cuda.is_contiguous(memory_format=torch.channels_last))
      self.assertTrue(res2_cpu.is_contiguous(memory_format=torch.channels_last))

My Opinion

We should fix this behaviour and make the CUDA and CPU variants behave the same way; we'll need to decide which behaviour to go ahead with. I believe all of this (including the tests to be added) should be carried out in a separate issue and PR, because it might be BC-breaking. If you feel the same, I can file an issue and send a follow-up PR for this later.

P.S. I know that .is_contiguous() is equivalent to .is_contiguous(memory_format=torch.contiguous_format); I just went with the latter in the test to be more explicit.

Please let me know if something is unclear.
cc @ngimel @kshitij12345


Additional Comments

@mruberry @kshitij12345 while digging deeper into memory_format, I found that PyTorch supports three memory formats in addition to torch.preserve_format:

  • torch.contiguous_format
  • torch.channels_last
  • torch.channels_last_3d (NOT DOCUMENTED)

We should add torch.channels_last_3d to the memory format documentation. I completely missed this because it was undocumented.

I can add this to the docs in a separate PR.
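
For reference, torch.channels_last_3d applies to 5-D (N, C, D, H, W) tensors:

    import torch

    x = torch.randn(2, 3, 4, 5, 6)  # N, C, D, H, W
    y = x.contiguous(memory_format=torch.channels_last_3d)
    assert y.is_contiguous(memory_format=torch.channels_last_3d)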

@mruberry

I had a chance to sync with @ngimel, @kimishpatel, and @VitalyFedyunin who previously worked on the out= behavior.

I'm worried about changing the behavior so close to the PyTorch 1.10 branch cut, so I'd like to propose the following:

  • we keep the current behavior, where CPU propagates the first tensor's memory format and CUDA only propagates a memory format if all the tensors share it, otherwise using contiguous_format
  • we add a test that this occurs
  • we finish this PR, adding an OpInfo for cat, fixing the out= skips, and adding the aliases (and testing them)

Then...

  • we file an issue noting that the memory format behavior of cat differs between CPU and CUDA, and suggest that the fix is to make the behavior consistently like CUDA's, only propagating a memory format when all the inputs share it
  • after the branch cut we make a PR with that fix and validate that it doesn't break anyone (part of that validation is making the change and seeing if anyone complains)

How does that sound, @AnirudhDagar? Great discovery!


AnirudhDagar commented Aug 25, 2021

Thanks, @mruberry, for having a discussion on this. I absolutely agree with everything you mentioned; it's definitely a good idea to hold this until after the 1.10 branch cut.

Just a small question:

we add a test that this occurs

Should we add this test here, or in the PR that will finally implement consistent behaviour after the 1.10 cut? I believe the test isn't really specific to this PR, so we could add it later as well. If we do add it now, then once the behaviour is changed, the other PR will also include the relevant updates to the test.

Also, other than this test (which we might want to add here or leave for later), everything is ready for this PR.

P.S. I already have the test ready to go: the second code snippet ("Click to expand: PASSING test_cat_out_memory_format") in my last comment. It's just a matter of adding it here or leaving it for later.

@mruberry

Thanks, @mruberry, for having a discussion on this. I absolutely agree with everything you mentioned; it's definitely a good idea to hold this until after the 1.10 branch cut.

Just a small question:

we add a test that this occurs

Should we add this test here, or in the PR that will finally implement consistent behaviour after the 1.10 cut? I believe the test isn't really specific to this PR, so we could add it later as well. If we do add it now, then once the behaviour is changed, the other PR will also include the relevant updates to the test.

Also, other than this test (which we might want to add here or leave for later), everything is ready for this PR.

I'd prefer adding the test in this PR (it's pretty simple), just to make the current behavior clear (the test can also link to the relevant issue), but it's OK; we can follow up on it separately as long as it's clear the behavior isn't changing in this PR.

@AnirudhDagar

@mruberry I've added the test in my recent commits. Let me know if that looks good or needs any improvement. Also raised issue #63998 describing the inconsistent behaviour.

P.S. I know that .is_contiguous() is equivalent to .is_contiguous(memory_format=torch.contiguous_format), and that tensors are contiguous by default. I still wrote it explicitly in the test for readability and clarity.

self.assertTrue(res1_cuda.is_contiguous(memory_format=torch.contiguous_format))
self.assertTrue(res1_cpu.is_contiguous(memory_format=torch.contiguous_format))

# Case 2: if out= is not the correct shape then the output is resized internally

Really nice comments


@mruberry mruberry left a comment


Nice work, @AnirudhDagar! This was a very challenging PR because it required understanding the new out= behavior, identifying a very subtle discrepancy between cat's CPU and CUDA implementations, and then updating the existing code and tests to account for that.


mruberry commented Sep 6, 2021

@krshrimali would you like to take another look?


@mruberry has imported this pull request. If you are a Facebook employee, you can view this diff on Phabricator.


@krshrimali krshrimali left a comment


LGTM, great work @AnirudhDagar 🎉 Thanks!


@mruberry merged this pull request in 1a1fb31.


Labels

  • cla signed
  • Merged
  • module: python array api (Issues related to the Python Array API)
  • oncall: jit (Add this issue/PR to JIT oncall triage queue)
  • open source
  • triaged (This issue has been looked at by a team member, and triaged and prioritized into an appropriate module)


Development

Successfully merging this pull request may close these issues.

Support concat alias to cat

10 participants