Migrate thnn_conv_depthwise2d from THC to ATen by peterbell10 · Pull Request #62006 · pytorch/pytorch

Conversation

@peterbell10
Collaborator

@peterbell10 peterbell10 commented Jul 22, 2021

Stack from ghstack:

Closes gh-24646, gh-24647

There is no `TensorIterator` equivalent to these kernels, so this just
migrates the existing kernels over to the ATen style.

I've benchmarked for contiguous tensors with this script:

```
import torch
shape = (10, 10, 100, 100)
x = torch.randn(*shape, device='cuda')
w = torch.randn((10, 1, 5, 5), device='cuda')

for _ in range(100):
    torch.nn.functional.conv2d(x, w, groups=10)
```

and similarly for backwards. I see these as the same to within measurement error.

|                   | Master (us) | This PR (us) |
|------------------:|:-----------:|:------------:|
|           Forward |    133.5    |     133.6    |
|  Backward (input) |    1,102    |     1,119    |
| Backward (weight) |    2,220    |     2,217    |

Differential Revision: D29883676
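
For reference, here is a rough sketch of how the backward timings could be collected. This is my own illustration rather than the exact script used for the table above; the CUDA-event timing loop and the combined input/weight gradient call are assumptions.

```
import torch

x = torch.randn(10, 10, 100, 100, device='cuda', requires_grad=True)
w = torch.randn(10, 1, 5, 5, device='cuda', requires_grad=True)
out = torch.nn.functional.conv2d(x, w, groups=10)
grad_out = torch.ones_like(out)

# Time 100 backward passes with CUDA events. The table reports input and
# weight gradients separately; here they are computed together for brevity.
start = torch.cuda.Event(enable_timing=True)
end = torch.cuda.Event(enable_timing=True)
start.record()
for _ in range(100):
    gx, gw = torch.autograd.grad(out, (x, w), grad_out, retain_graph=True)
end.record()
torch.cuda.synchronize()
print(f"{start.elapsed_time(end) * 1000 / 100:.1f} us per backward iteration")
```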

@facebook-github-bot
Contributor

facebook-github-bot commented Jul 22, 2021

💊 CI failures summary and remediations

As of commit c01b02e (more details on the Dr. CI page and at hud.pytorch.org/pr/62006):


  • 5/5 failures introduced in this PR

🕵️ 4 new failures recognized by patterns

The following CI failures do not appear to be due to upstream breakages:

See GitHub Actions build Windows CI (pytorch-win-vs2019-cpu-py3) / test (default, 1, 2, windows.4xlarge) (1/4)

Step: "Run test scripts" (full log | diagnosis details | 🔁 rerun)

2021-07-23T02:33:14.8136122Z ModuleNotFoundError: No module named 'torch.version'
2021-07-23T02:33:14.8081868Z 0228:15f4 @ 05188343 - LdrpPreprocessDllName - INFO: DLL api-ms-win-crt-runtime-l1-1-0.dll was redirected to C:\Windows\SYSTEM32\ucrtbase.dll by API set
2021-07-23T02:33:14.8083087Z 0228:15f4 @ 05188343 - LdrpGetProcedureAddress - INFO: Locating procedure "RtlInitializeSListHead" by name
2021-07-23T02:33:14.8084077Z 0228:15f4 @ 05188343 - LdrpInitializeNode - INFO: Calling init routine 00007FFDC7C33678 for DLL "C:\Jenkins\Miniconda3\DLLs\_lzma.pyd"
2021-07-23T02:33:14.8084833Z 0228:15f4 @ 05188343 - LdrpLoadDllInternal - RETURN: Status: 0x00000000
2021-07-23T02:33:14.8085354Z 0228:15f4 @ 05188343 - LdrLoadDll - RETURN: Status: 0x00000000
2021-07-23T02:33:14.8085981Z 0228:15f4 @ 05188343 - LdrpGetProcedureAddress - INFO: Locating procedure "PyInit__lzma" by name
2021-07-23T02:33:14.8133597Z Traceback (most recent call last):
2021-07-23T02:33:14.8134293Z   File "<string>", line 1, in <module>
2021-07-23T02:33:14.8135012Z   File "C:\actions-runner\_work\pytorch\pytorch\torch\__init__.py", line 29, in <module>
2021-07-23T02:33:14.8135599Z     from .version import __version__ as __version__
2021-07-23T02:33:14.8136122Z ModuleNotFoundError: No module named 'torch.version'
2021-07-23T02:33:14.8174766Z 0228:15f4 @ 05188343 - LdrLoadDll - ENTER: DLL name: api-ms-win-appmodel-runtime-l1-1-2
2021-07-23T02:33:14.8176099Z 0228:15f4 @ 05188343 - LdrpPreprocessDllName - INFO: DLL api-ms-win-appmodel-runtime-l1-1-2 was redirected to C:\Windows\SYSTEM32\kernel.appcore.dll by API set
2021-07-23T02:33:14.8177265Z 0228:15f4 @ 05188343 - LdrpLoadDllInternal - ENTER: DLL name: C:\Windows\SYSTEM32\kernel.appcore.dll
2021-07-23T02:33:14.8178012Z 0228:15f4 @ 05188343 - LdrpFindKnownDll - ENTER: DLL name: kernel.appcore.dll
2021-07-23T02:33:14.8178606Z 0228:15f4 @ 05188343 - LdrpFindKnownDll - RETURN: Status: 0x00000000
2021-07-23T02:33:14.8179290Z 0228:15f4 @ 05188343 - LdrpMinimalMapModule - ENTER: DLL name: C:\Windows\System32\kernel.appcore.dll
2021-07-23T02:33:14.8180031Z ModLoad: 00007ffd`d2d40000 00007ffd`d2d51000   C:\Windows\System32\kernel.appcore.dll
2021-07-23T02:33:14.8180653Z 0228:15f4 @ 05188343 - LdrpMinimalMapModule - RETURN: Status: 0x00000000
2021-07-23T02:33:14.8181594Z 0228:15f4 @ 05188343 - LdrpFindDllActivationContext - INFO: Probing for the manifest of DLL "C:\Windows\System32\kernel.appcore.dll" failed with status 0xc000008a
2021-07-23T02:33:14.8182923Z 0228:15f4 @ 05188343 - LdrpPreprocessDllName - INFO: DLL api-ms-win-core-profile-l1-1-0.dll was redirected to C:\Windows\SYSTEM32\kernelbase.dll by API set

See CircleCI build pytorch_macos_10_13_py3_test (2/4)

Step: "Test" (full log | diagnosis details | 🔁 rerun)

Jul 23 03:39:12 test_remote_message_script_de...yUniqueId(created_on=0, local_id=0) to be created.
Jul 23 03:38:44 frame #12: std::__1::__function::__func<std::__1::__bind<torch::distributed::rpc::ProcessGroupAgent::enqueueRecv(torch::distributed::rpc::RecvWork)::$_6, torch::distributed::rpc::RecvWork>, std::__1::allocator<std::__1::__bind<torch::distributed::rpc::ProcessGroupAgent::enqueueRecv(torch::distributed::rpc::RecvWork)::$_6, torch::distributed::rpc::RecvWork> >, void ()>::operator()() + 42 (0x11b35d22a in libtorch_cpu.dylib)
Jul 23 03:38:44 frame #13: c10::ThreadPool::main_loop(unsigned long) + 569 (0x1146c9369 in libc10.dylib)
Jul 23 03:38:44 frame #14: void* std::__1::__thread_proxy<std::__1::tuple<std::__1::unique_ptr<std::__1::__thread_struct, std::__1::default_delete<std::__1::__thread_struct> >, c10::ThreadPool::ThreadPool(int, int, std::__1::function<void ()>)::$_0> >(void*) + 67 (0x1146c9a13 in libc10.dylib)
Jul 23 03:38:44 frame #15: _pthread_start + 148 (0x7fff6f9cb109 in libsystem_pthread.dylib)
Jul 23 03:38:44 frame #16: thread_start + 15 (0x7fff6f9c6b8b in libsystem_pthread.dylib)
Jul 23 03:38:44 
Jul 23 03:38:45 ok (4.088s)
Jul 23 03:38:53   test_remote_message_dropped_pickle (__main__.FaultyFaultyAgentRpcTestWithSpawn) ... ok (8.252s)
Jul 23 03:39:01   test_remote_message_dropped_pickle_to_self (__main__.FaultyFaultyAgentRpcTestWithSpawn) ... ok (8.369s)
Jul 23 03:39:09   test_remote_message_script_delay_timeout (__main__.FaultyFaultyAgentRpcTestWithSpawn) ... ok (7.338s)
Jul 23 03:39:12   test_remote_message_script_delay_timeout_to_self (__main__.FaultyFaultyAgentRpcTestWithSpawn) ... [E request_callback_no_python.cpp:555] Received error while processing request type 260: falseINTERNAL ASSERT FAILED at "../torch/csrc/distributed/rpc/rref_context.cpp":390, please report a bug to PyTorch. Expected OwnerRRef with id GloballyUniqueId(created_on=0, local_id=0) to be created.
Jul 23 03:39:12 Exception raised from getOwnerRRef at ../torch/csrc/distributed/rpc/rref_context.cpp:390 (most recent call first):
Jul 23 03:39:12 frame #0: c10::Error::Error(c10::SourceLocation, std::__1::basic_string<char, std::__1::char_traits<char>, std::__1::allocator<char> >) + 98 (0x113e276b2 in libc10.dylib)
Jul 23 03:39:12 frame #1: c10::detail::torchCheckFail(char const*, char const*, unsigned int, std::__1::basic_string<char, std::__1::char_traits<char>, std::__1::allocator<char> > const&) + 106 (0x113e25e2a in libc10.dylib)
Jul 23 03:39:12 frame #2: c10::detail::torchInternalAssertFail(char const*, char const*, unsigned int, char const*, std::__1::basic_string<char, std::__1::char_traits<char>, std::__1::allocator<char> > const&) + 64 (0x113e26060 in libc10.dylib)
Jul 23 03:39:12 frame #3: torch::distributed::rpc::RRefContext::getOwnerRRef(torch::distributed::rpc::GloballyUniqueId const&, bool) + 1711 (0x1191b872f in libtorch_cpu.dylib)
Jul 23 03:39:12 frame #4: torch::distributed::rpc::RequestCallbackNoPython::assignOwnerRRef(torch::distributed::rpc::GloballyUniqueId const&, torch::distributed::rpc::GloballyUniqueId const&, c10::intrusive_ptr<c10::ivalue::Future, c10::detail::intrusive_target_default_null_type<c10::ivalue::Future> >) const + 86 (0x1191a2f86 in libtorch_cpu.dylib)
Jul 23 03:39:12 frame #5: torch::distributed::rpc::RequestCallbackImpl::processScriptRemoteCall(torch::distributed::rpc::RpcCommandBase&, std::__1::vector<c10::Stream, std::__1::allocator<c10::Stream> >) const + 376 (0x1152757b8 in libtorch_python.dylib)
Jul 23 03:39:12 frame #6: torch::distributed::rpc::RequestCallbackNoPython::processRpc(torch::distributed::rpc::RpcCommandBase&, torch::distributed::rpc::MessageType const&, std::__1::vector<c10::Stream, std::__1::allocator<c10::Stream> >) const + 437 (0x1191a1bd5 in libtorch_cpu.dylib)
Jul 23 03:39:12 frame #7: torch::distributed::rpc::RequestCallbackImpl::processRpcWithErrors(torch::distributed::rpc::RpcCommandBase&, torch::distributed::rpc::MessageType const&, std::__1::vector<c10::Stream, std::__1::allocator<c10::Stream> >) const + 74 (0x11527652a in libtorch_python.dylib)
Jul 23 03:39:12 frame #8: c10::intrusive_ptr<c10::ivalue::Future, c10::detail::intrusive_target_default_null_type<c10::ivalue::Future> > c10::ivalue::Future::thenAsync<torch::distributed::rpc::RequestCallbackNoPython::processMessage(torch::distributed::rpc::Message&, std::__1::vector<c10::Stream, std::__1::allocator<c10::Stream> >) const::$_1>(torch::distributed::rpc::RequestCallbackNoPython::processMessage(torch::distributed::rpc::Message&, std::__1::vector<c10::Stream, std::__1::allocator<c10::Stream> >) const::$_1, std::__1::shared_ptr<c10::Type>)::'lambda'(c10::ivalue::Future&)::operator()(c10::ivalue::Future&) + 223 (0x1191a989f in libtorch_cpu.dylib)

See CircleCI build pytorch_linux_xenial_py3_6_gcc5_4_build (3/4)

Step: "(Optional) Merge target branch" (full log | diagnosis details | 🔁 rerun)

Automatic merge failed; fix conflicts and then commit the result.
CONFLICT (add/add): Merge conflict in .circleci/generate_config_yml.py
Auto-merging .circleci/generate_config_yml.py
CONFLICT (add/add): Merge conflict in .circleci/config.yml
Auto-merging .circleci/config.yml
CONFLICT (add/add): Merge conflict in .circleci/cimodel/data/simple/docker_definitions.py
Auto-merging .circleci/cimodel/data/simple/docker_definitions.py
CONFLICT (add/add): Merge conflict in .circleci/cimodel/data/pytorch_build_data.py
Auto-merging .circleci/cimodel/data/pytorch_build_data.py
CONFLICT (add/add): Merge conflict in .circleci/cimodel/data/binary_build_data.py
Auto-merging .circleci/cimodel/data/binary_build_data.py
Automatic merge failed; fix conflicts and then commit the result.


Exited with code exit status 1

See CircleCI build pytorch_xla_linux_bionic_py3_6_clang9_build (4/4)

Step: "(Optional) Merge target branch" (full log | diagnosis details | 🔁 rerun)

Automatic merge failed; fix conflicts and then commit the result.
CONFLICT (add/add): Merge conflict in .circleci/generate_config_yml.py
Auto-merging .circleci/generate_config_yml.py
CONFLICT (add/add): Merge conflict in .circleci/config.yml
Auto-merging .circleci/config.yml
CONFLICT (add/add): Merge conflict in .circleci/cimodel/data/simple/docker_definitions.py
Auto-merging .circleci/cimodel/data/simple/docker_definitions.py
CONFLICT (add/add): Merge conflict in .circleci/cimodel/data/pytorch_build_data.py
Auto-merging .circleci/cimodel/data/pytorch_build_data.py
CONFLICT (add/add): Merge conflict in .circleci/cimodel/data/binary_build_data.py
Auto-merging .circleci/cimodel/data/binary_build_data.py
Automatic merge failed; fix conflicts and then commit the result.


Exited with code exit status 1


1 failure not recognized by patterns:

| Job | Step | Action |
|-----|------|--------|
| GitHub Actions Test tools / test | Install dependencies | 🔁 rerun |

Preview docs built from this PR

This comment was automatically generated by Dr. CI. Follow this link to opt-out of these comments for your Pull Requests.

Please report bugs/suggestions to the (internal) Dr. CI Users group.


peterbell10 added a commit that referenced this pull request Jul 22, 2021
ghstack-source-id: 9ee4fdd
Pull Request resolved: #62006
@peterbell10 peterbell10 requested a review from ngimel July 22, 2021 01:55
@ezyang ezyang removed their request for review July 22, 2021 14:11
@albanD albanD removed their request for review July 22, 2021 18:38
conv_depthwise2d_backward_out(
    *self, *grad_output, grad_input, *weight,
    kernel_size[1], kernel_size[0],
    stride[1], stride[0],
Collaborator

this is scary, can we add a test that would have caught it?

Collaborator Author

I've added a gradcheck test that fails for the previous version.
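
For context, here is a minimal sketch of the kind of gradcheck that catches a swapped height/width argument. It is not necessarily the exact test added in the PR; the shapes and the direct use of `torch.autograd.gradcheck` are illustrative.

```
import torch
import torch.nn.functional as F
from torch.autograd import gradcheck

# An asymmetric kernel, stride and padding make the H/W axes distinguishable,
# so mixing up kernel_size[0]/kernel_size[1] (or the stride/padding pairs) in
# the backward kernel makes the numerical and analytical gradients disagree.
x = torch.randn(2, 4, 8, 10, dtype=torch.double, device='cuda', requires_grad=True)
w = torch.randn(4, 1, 3, 5, dtype=torch.double, device='cuda', requires_grad=True)

def depthwise(x, w):
    # groups == in_channels selects the depthwise convolution path
    return F.conv2d(x, w, stride=(1, 2), padding=(1, 2), groups=4)

assert gradcheck(depthwise, (x, w))
```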

 static inline bool cudnn_conv_use_channels_last(const at::Tensor& input, const at::Tensor& weight) {
   // disable NHWC for float64 input.
-  if (!detail::getCUDAHooks().compiledWithCuDNN() ||
+  if (!at::detail::getCUDAHooks().compiledWithCuDNN() ||
Collaborator Author

Without this I get compilation errors because it's looking in the wrong namespace:

/home/peter/git/pytorch/aten/src/ATen/native/ConvUtils.h(87): error: namespace "at::native::detail" has no member "getCUDAHooks"

peterbell10 added a commit that referenced this pull request Jul 23, 2021
ghstack-source-id: b585bf4
Pull Request resolved: #62006
@ngimel
Collaborator

ngimel commented Jul 23, 2021

Test failures look unrelated.

@ngimel
Collaborator

ngimel commented Jul 23, 2021

@ngimel has imported this pull request. If you are a Facebook employee, you can view this diff on Phabricator.

@facebook-github-bot
Contributor

@ngimel merged this pull request in de3a4eb.

@peterbell10
Collaborator Author

peterbell10 commented Jul 27, 2021

@ngimel check_backward_compatibility is failing and I missed it because that CI job wasn't run on this PR for some reason.
https://app.circleci.com/pipelines/github/pytorch/pytorch/356638/workflows/3485a5a2-1869-4368-b00e-3beb36556dcb/jobs/15037920?invite=true#step-107-3014

@facebook-github-bot
Contributor

This pull request has been reverted by acaac70.

@ejguan
Contributor

ejguan commented Jul 27, 2021

Reverting this PR since it breaks check_backward_compatibility. Please add BC-breaking changes to the allow list in https://github.com/pytorch/pytorch/blob/master/test/backward_compatibility/check_backward_compatibility.py
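
For what it's worth, the allow list in that file is roughly a list of (operator name, expiry date) tuples. A hypothetical entry for the removed `thnn_conv_depthwise2d` overloads might look like the sketch below; the names and dates are placeholders, and the exact format should be checked against the actual file.

```
import datetime

# Hypothetical ALLOW_LIST entries; the real file defines the list and the
# matching logic, these lines only illustrate the shape of an entry.
ALLOW_LIST = [
    ("aten::thnn_conv_depthwise2d", datetime.date(2021, 9, 30)),
    ("aten::thnn_conv_depthwise2d_forward", datetime.date(2021, 9, 30)),
    ("aten::thnn_conv_depthwise2d_backward", datetime.date(2021, 9, 30)),
]
```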

@facebook-github-bot facebook-github-bot deleted the gh/peterbell10/100/head branch July 31, 2021 14:17