first cut of adding backend fallback for conjugation by bdhirsh · Pull Request #43702 · pytorch/pytorch · GitHub

Conversation

@bdhirsh
Contributor

@bdhirsh bdhirsh commented Aug 27, 2020

Stack from ghstack:

bdhirsh added a commit that referenced this pull request Aug 27, 2020
ghstack-source-id: 10804cc
Pull Request resolved: #43702
// Is it reasonable to set the memory on the stack in place like this?
// I don't see other examples of fallbacks doing this,
// so instead I'm pushing the new conjugated vector onto the stack
auto conjugated_tensor = tensor.conj();
Contributor Author

See the comment here - I don't think I have a clear understanding of how a given fallback in the dispatch workflow is supposed to modify the stack.

My understanding is that the stack represents the call stack, i.e. the top of it contains the arguments to the kernel function that will eventually get called after all other dispatch fallback functions.

Should my conjugate fallback be modifying those arguments in place? Pushing new arguments onto the stack?

Contributor

The stack is a mutable data structure, and the invariant is that when you call some downstream function, the top N entries of the stack (however many arguments the function takes) are its arguments. On return, the invariant is that the input arguments have been popped from the stack and the outputs have been pushed onto it.

before stack: [arg1, arg2, arg3]  # NB: maybe it's the other way, I can never remember
call function
after stack: [ret1, ret2]

As the stack is mutable, an easy way you can modify the stack is simply by overwriting the direct entries with the new, explicitly conjugated (not using the bit) version. I'm not sure if the stack API has a direct way of expressing this, but stacks are just vectors so you can implement it by hand yourself.
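
For concreteness, a minimal sketch of a boxed fallback that respects that invariant (the name myFallback is illustrative; it uses the two-argument boxed signature seen elsewhere in this PR, and a real fallback registered to a specific key would exclude that key, or otherwise change the dispatch state, before redispatching so it doesn't re-enter itself):

#include <ATen/core/dispatch/Dispatcher.h>
#include <ATen/core/stack.h>

// On entry, the top `num_arguments` IValues on the stack are the arguments
// to `op`; on return they must have been replaced by the op's outputs.
void myFallback(const c10::OperatorHandle& op, torch::jit::Stack* stack) {
  const auto num_arguments = op.schema().arguments().size();

  // Peek at the arguments without popping them; this is where a fallback
  // would rewrite tensor arguments (e.g. materialize their conjugation).
  auto args = torch::jit::last(stack, num_arguments);
  (void)args;

  // Redispatch: the callee pops the arguments and pushes the outputs,
  // so the invariant holds for our caller as well.
  op.callBoxed(stack);
}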

Contributor

Yeah, I looked over your implementation and it doesn't work. I can give you more hints about how to fix it, but perhaps the invariant described above is enough?

Contributor

BTW, the use of conj() here directly is also unlikely to work, because you still have the conjugate bit on the tensor and you'll redispatch here. You need an operation like conj_materialize() which takes a tensor whose conjugate bit is set and returns the conjugated version without the conjugate bit set. You could probably do this by first allocating a new tensor a la conj_view with the conjugate bit turned off, and then calling conj().

Contributor Author

thanks - I fixed the issues described (I think), but I'm still hitting problems. It might help to lay out my understanding of the stack in relation to a dispatch chain:

A dispatch "chain" is a sequence of function calls, consisting of (possibly several) "pre-processing" functions followed by a single backend kernel function. The set of functions called are in order of the priority of the dispatch keys, and are determined by:

  1. the set of dispatch keys in the current context (in the tensor arguments, and local/global context)
    (2) the mappings from from a given op/dispatch key to kernel functions or "fallback" functions. In the case of this conjugate function, we want the function to be called for all tensors with the Conjugate dispatch key set, regardless of the op)

A reasonable dispatch chain in the case of my change: Conjugate -> Add(cpu).

Now, I'm currently thinking of the stack as being passed along to each of the functions in the dispatch chain, such that the state of the stack at the end of the first function is the same as the state at the beginning of the second function.

E.g. suppose we have:
Tensor a = ...
Tensor b = ...
Tensor c = a.conj_view() + b

Invoking the "add" function should eventually invoke the dispatcher, first calling conjugateFallback() and then the add() kernel. I'd expect the stack at the beginning of the call to conjugateFallback to look like:
[a (tensor, conj bit set); b (tensor, conj bit not set)]
and after the call to look like:
[a (conj bit NOT set, but the imaginary part negated); b (unchanged, conj bit not set)]
(What we'd want the stack to look like at the beginning of the call to the add() kernel)

Which I believe is what the above is now doing. However, I'm currently hitting an issue where the arguments at the top of the stack are not being recognized as tensors (their tag is set to c10::IValue::Tag::None). Later, the call to the add() kernel function fails with "There were no tensor arguments to this function (e.g., you passed an empty list of Tensors)", which makes me think I corrupted the stack somewhere in my conjugate fallback.

Contributor

This description looks accurate.


TEST(BackendFallbackTest, ConjugateTest) {
auto m = MAKE_TORCH_LIBRARY_IMPL(_, TESTING_ONLY_GenericMode);
m.fallback(torch::CppFunction::makeFromBoxedFunction<&conjugateFallback>());
Contributor Author

Couldn't get this line to compile :(

/home/hirsheybar/pytorch/build/../aten/src/ATen/core/boxing/KernelFunction_impl.h:22: undefined reference to `at::conjugateFallback(c10::OperatorHandle const&, std::vector<c10::IValue, std::allocator<c10::IValue> >*)'
collect2: error: ld returned 1 exit status

conjugateFallback() is defined in ConjugateFallback.h, included at the top of this file. It's in the at:: namespace, which we're using here as well. I also tried replacing torch::jit::Stack* in the function header with std::vector<c10::IValue, std::allocator<c10::IValue> >* as the error suggests, which didn't fix it (I'm assuming Stack is aliased to that type).

Tensor conj_view(const Tensor& self) {
Tensor self_;
if (self.is_quantized()) {
auto impl = c10::make_intrusive<QTensorImpl>(
Contributor Author

I pulled the quantized logic from alias_with_sizes_and_strides, but I'm not actually sure if I need it.

Contributor

We don't have quantized complex, so probably not :)

return !at::impl::variable_excluded_from_dispatch();
}

// TODO: do I need to add this method directly on the Tensor to implement conj_view() in UnaryOps?
Contributor Author

I added a getter/setter for conjugate() on the TensorImpl class, but I needed to add the getter to the Tensor class as well in order to flip the state correctly in the conj_view() function, since it deals with Tensors and not TensorImpls. This doesn't feel right though- let me know if I'm missing something.

Contributor

This is OK for a prototype. In the real implementation we'll probably have to bikeshed under what name we should expose this functionality, and probably it should get a native functions schema.

@bdhirsh bdhirsh requested a review from ezyang August 27, 2020 14:34
@dr-ci

dr-ci bot commented Aug 27, 2020

💊 CI failures summary and remediations

As of commit d31d4fe (more details on the Dr. CI page):



🕵️ 10 new failures recognized by patterns

The following CI failures do not appear to be due to upstream breakages:

See CircleCI build pytorch_linux_xenial_py3_clang5_mobile_build (1/10)

Step: "Build" (full log | diagnosis details | 🔁 rerun)

Tmp/CheckSymbolExists.c: In function \'main\':\n/var/lib/jenkins/workspace/build_test_custom_build/build_default_libtorch/CMakeFiles/CMakeTmp/CheckSymbolExists.c:8:19: error: \'strtod_l\' undeclared (first use in this function)\n return ((int*)(&strtod_l))[argc];\n ^\n/var/lib/jenkins/workspace/build_test_custom_build/build_default_libtorch/CMakeFiles/CMakeTmp/CheckSymbolExists.c:8:19: note: each undeclared identifier is reported only once for each function it appears in\n" }
Sep 08 14:19:39     input: Tensor) -> Tensor: 
Sep 08 14:19:39     input0 = torch._convolution(input, self.weight, None, [2, 2], [1, 1], [1, 1], False, [0, 0], 1, False, False, True, True) 
Sep 08 14:19:39              ~~~~~~~~~~~~~~~~~~ <--- HERE 
Sep 08 14:19:39     return input0 
Sep 08 14:19:39  
Sep 08 14:19:39 test/mobile/custom_build/build.sh: line 81: 14102 Aborted                 (core dumped) ./Predictor "${MODEL}" > output.txt 
Sep 08 14:19:39 + sccache_epilogue 
Sep 08 14:19:39 + echo '=================== sccache compilation log ===================' 
Sep 08 14:19:39 + python /var/lib/jenkins/workspace/.jenkins/pytorch/print_sccache_log.py /var/lib/jenkins/sccache_error.log 
Sep 08 14:19:39 =================== sccache compilation log =================== 
mp/CheckSymbolExists.c: In function \'main\':\n/var/lib/jenkins/workspace/build_test_custom_build/build_default_libtorch/CMakeFiles/CMakeTmp/CheckSymbolExists.c:8:19: error: \'strtod_l\' undeclared (first use in this function)\n   return ((int*)(&strtod_l))[argc];\n                   ^\n/var/lib/jenkins/workspace/build_test_custom_build/build_default_libtorch/CMakeFiles/CMakeTmp/CheckSymbolExists.c:8:19: note: each undeclared identifier is reported only once for each function it appears in\n" } 
Sep 08 14:19:39  
Sep 08 14:19:39 =========== If your build fails, please take a look at the log above for possible reasons =========== 
Sep 08 14:19:39 + echo '=========== If your build fails, please take a look at the log above for possible reasons ===========' 
Sep 08 14:19:39 + sccache --show-stats 
Sep 08 14:19:39 Compile requests              2239 
Sep 08 14:19:39 Compile requests executed     1604 
Sep 08 14:19:39 Cache hits                    1133 
Sep 08 14:19:39 Cache misses                   467 
Sep 08 14:19:39 Cache timeouts                   0 
Sep 08 14:19:39 Cache read errors                0 

See CircleCI build pytorch_linux_xenial_py3_clang5_mobile_custom_build_dynamic (2/10)

Step: "Build" (full log | diagnosis details | 🔁 rerun)

lExists.c: In function \'main\':\n/var/lib/jenkins/workspace/build_test_custom_build/build_custom_libtorch_dynamic/CMakeFiles/CMakeTmp/CheckSymbolExists.c:8:19: error: \'strtod_l\' undeclared (first use in this function)\n return ((int*)(&strtod_l))[argc];\n ^\n/var/lib/jenkins/workspace/build_test_custom_build/build_custom_libtorch_dynamic/CMakeFiles/CMakeTmp/CheckSymbolExists.c:8:19: note: each undeclared identifier is reported only once for each function it appears in\n" }
Sep 08 14:31:19     input: Tensor) -> Tensor: 
Sep 08 14:31:19     input0 = torch._convolution(input, self.weight, None, [2, 2], [1, 1], [1, 1], False, [0, 0], 1, False, False, True, True) 
Sep 08 14:31:19              ~~~~~~~~~~~~~~~~~~ <--- HERE 
Sep 08 14:31:19     return input0 
Sep 08 14:31:19  
Sep 08 14:31:19 test/mobile/custom_build/build.sh: line 81: 21609 Aborted                 (core dumped) ./Predictor "${MODEL}" > output.txt 
Sep 08 14:31:19 =================== sccache compilation log =================== 
Sep 08 14:31:19 + sccache_epilogue 
Sep 08 14:31:19 + echo '=================== sccache compilation log ===================' 
Sep 08 14:31:19 + python /var/lib/jenkins/workspace/.jenkins/pytorch/print_sccache_log.py /var/lib/jenkins/sccache_error.log 
Exists.c: In function \'main\':\n/var/lib/jenkins/workspace/build_test_custom_build/build_custom_libtorch_dynamic/CMakeFiles/CMakeTmp/CheckSymbolExists.c:8:19: error: \'strtod_l\' undeclared (first use in this function)\n   return ((int*)(&strtod_l))[argc];\n                   ^\n/var/lib/jenkins/workspace/build_test_custom_build/build_custom_libtorch_dynamic/CMakeFiles/CMakeTmp/CheckSymbolExists.c:8:19: note: each undeclared identifier is reported only once for each function it appears in\n" } 
Sep 08 14:31:19  
Sep 08 14:31:19 =========== If your build fails, please take a look at the log above for possible reasons =========== 
Sep 08 14:31:19 + echo '=========== If your build fails, please take a look at the log above for possible reasons ===========' 
Sep 08 14:31:19 + sccache --show-stats 
Sep 08 14:31:19 Compile requests              2828 
Sep 08 14:31:19 Compile requests executed     2192 
Sep 08 14:31:19 Cache hits                      22 
Sep 08 14:31:19 Cache misses                  2166 
Sep 08 14:31:19 Cache timeouts                   0 
Sep 08 14:31:19 Cache read errors                0 

See CircleCI build pytorch_linux_xenial_cuda9_2_cudnn7_py3_gcc5_4_build (3/10)

Step: "(Optional) Merge target branch" (full log | diagnosis details | 🔁 rerun)

Automatic merge failed; fix conflicts and then commit the result.
CONFLICT (add/add): Merge conflict in .circleci/scripts/binary_linux_build.sh 
Auto-merging .circleci/scripts/binary_linux_build.sh 
CONFLICT (add/add): Merge conflict in .circleci/config.yml 
Auto-merging .circleci/config.yml 
CONFLICT (add/add): Merge conflict in .circleci/cimodel/data/pytorch_build_definitions.py 
Auto-merging .circleci/cimodel/data/pytorch_build_definitions.py 
CONFLICT (add/add): Merge conflict in .circleci/cimodel/data/pytorch_build_data.py 
Auto-merging .circleci/cimodel/data/pytorch_build_data.py 
CONFLICT (add/add): Merge conflict in .circleci/cimodel/data/binary_build_data.py 
Auto-merging .circleci/cimodel/data/binary_build_data.py 
Automatic merge failed; fix conflicts and then commit the result. 

See CircleCI build pytorch_linux_bionic_py3_6_clang9_test (4/10)

Step: "Run tests" (full log | diagnosis details | 🔁 rerun)

Sep 08 16:15:56 ERROR:sccache::server: Compilation failed: Output { status: ExitStatus(ExitStatus(256)), stdout: "", stderr: "/var/lib/jenkins/.cache/torch_extensions/test_compilation_error_formatting/main.cpp: In function \'int main()\':\n/var/lib/jenkins/.cache/torch_extensions/test_compilation_error_formatting/main.cpp:2:23: error: expected \';\' before \'}\' token\n int main() { return 0 }\n ^\n" }
Sep 08 16:15:55     raise RuntimeError(err_message) 
Sep 08 16:15:55 RuntimeError: test_torch failed! 
Sep 08 16:15:56  
Sep 08 16:15:56 real	36m56.136s 
Sep 08 16:15:56 user	42m14.495s 
Sep 08 16:15:56 sys	2m59.395s 
Sep 08 16:15:56 + cleanup 
Sep 08 16:15:56 + retcode=1 
Sep 08 16:15:56 + set +x 
Sep 08 16:15:56 =================== sccache compilation log =================== 
Sep 08 16:15:56 ERROR:sccache::server: Compilation failed: Output { status: ExitStatus(ExitStatus(256)), stdout: "", stderr: "/var/lib/jenkins/.cache/torch_extensions/test_compilation_error_formatting/main.cpp: In function \'int main()\':\n/var/lib/jenkins/.cache/torch_extensions/test_compilation_error_formatting/main.cpp:2:23: error: expected \';\' before \'}\' token\n int main() { return 0 }\n                       ^\n" } 
Sep 08 16:15:56  
Sep 08 16:15:56 =========== If your build fails, please take a look at the log above for possible reasons =========== 
Sep 08 16:15:56 Compile requests                 65 
Sep 08 16:15:56 Compile requests executed        35 
Sep 08 16:15:56 Cache hits                        2 
Sep 08 16:15:56 Cache misses                     32 
Sep 08 16:15:56 Cache timeouts                    0 
Sep 08 16:15:56 Cache read errors                 0 
Sep 08 16:15:56 Forced recaches                   0 
Sep 08 16:15:56 Cache write errors                0 

See CircleCI build pytorch_macos_10_13_py3_test (5/10)

Step: "Test" (full log | diagnosis details | 🔁 rerun)

Sep 08 10:12:23 AssertionError: False is not true : Tensor.conj_view is missing documentation
Sep 08 10:12:23   test_view_view_cpu (__main__.TestViewOpsCPU) ... ok (0.002s) 
Sep 08 10:12:23  
Sep 08 10:12:23 ====================================================================== 
Sep 08 10:12:23 FAIL [0.008s]: test_doc (__main__.TestTorch) 
Sep 08 10:12:23 ---------------------------------------------------------------------- 
Sep 08 10:12:23 Traceback (most recent call last): 
Sep 08 10:12:23   File "test_torch.py", line 216, in test_doc 
Sep 08 10:12:23     'sparse_resize_and_clear_', 
Sep 08 10:12:23   File "test_torch.py", line 187, in test_namespace 
Sep 08 10:12:23     self.assertTrue(has_doc, '{} is missing documentation'.format(full_name)) 
Sep 08 10:12:23 AssertionError: False is not true : Tensor.conj_view is missing documentation 
Sep 08 10:12:23  
Sep 08 10:12:23 ---------------------------------------------------------------------- 
Sep 08 10:12:23 Ran 2901 tests in 290.239s 
Sep 08 10:12:23  
Sep 08 10:12:23 FAILED (failures=1, skipped=154) 
Sep 08 10:12:23  
Sep 08 10:12:23 Generating XML reports... 
Sep 08 10:12:23 Generated XML report: test-reports/dist-gloo/TEST-TestTensorDeviceOpsCPU-20200908100733.xml 
Sep 08 10:12:23 Generated XML report: test-reports/dist-gloo/TEST-TestTorch-20200908100733.xml 
Sep 08 10:12:24 Generated XML report: test-reports/dist-gloo/TEST-TestTorchDeviceTypeCPU-20200908100733.xml 

See CircleCI build pytorch_windows_vs2019_py36_cuda10.1_test2 (6/10)

Step: "Test" (full log | diagnosis details | 🔁 rerun)

AssertionError: False is not true : Tensor.conj_view is missing documentation
  test_view_view_cuda (__main__.TestViewOpsCUDA) ... ok (0.001s) 
 
====================================================================== 
FAIL [0.005s]: test_doc (__main__.TestTorch) 
---------------------------------------------------------------------- 
Traceback (most recent call last): 
  File "test_torch.py", line 216, in test_doc 
    'sparse_resize_and_clear_', 
  File "test_torch.py", line 187, in test_namespace 
    self.assertTrue(has_doc, '{} is missing documentation'.format(full_name)) 
AssertionError: False is not true : Tensor.conj_view is missing documentation 
 
---------------------------------------------------------------------- 
Ran 6626 tests in 584.486s 
 
FAILED (failures=1, skipped=383) 
 
Generating XML reports... 
Generated XML report: test-reports\python-unittest\TEST-TestDevicePrecisionCUDA-20200908155903.xml 
Generated XML report: test-reports\python-unittest\TEST-TestTensorDeviceOpsCPU-20200908155903.xml 
Generated XML report: test-reports\python-unittest\TEST-TestTensorDeviceOpsCUDA-20200908155903.xml 

See CircleCI build pytorch_linux_xenial_py3_clang5_asan_test2 (7/10)

Step: "Run tests" (full log | diagnosis details | 🔁 rerun)

Sep 08 15:50:48 SUMMARY: UndefinedBehaviorSanitizer: undefined-behavior /var/lib/jenkins/workspace/aten/src/ATen/Utils.cpp:11:3 in
Sep 08 15:50:48     #7 0x55ef750197eb in PyEval_EvalCode /tmp/build/80754af9/python_1588903631989/work/Python/ceval.c:731 
Sep 08 15:50:48     #8 0x55ef75099e73 in run_mod /tmp/build/80754af9/python_1588903631989/work/Python/pythonrun.c:1025 
Sep 08 15:50:48     #9 0x55ef75099f0c in PyRun_StringFlags /tmp/build/80754af9/python_1588903631989/work/Python/pythonrun.c:949 
Sep 08 15:50:48     #10 0x55ef75099f6e in PyRun_SimpleStringFlags /tmp/build/80754af9/python_1588903631989/work/Python/pythonrun.c:445 
Sep 08 15:50:48     #11 0x55ef7509dd72 in run_command /tmp/build/80754af9/python_1588903631989/work/Modules/main.c:301 
Sep 08 15:50:48     #12 0x55ef7509dd72 in Py_Main /tmp/build/80754af9/python_1588903631989/work/Modules/main.c:749 
Sep 08 15:50:48     #13 0x55ef74f67f2d in main /tmp/build/80754af9/python_1588903631989/work/Programs/python.c:69 
Sep 08 15:50:48     #14 0x7f828e16f83f in __libc_start_main /build/glibc-e6zv40/glibc-2.23/csu/../csu/libc-start.c:291 
Sep 08 15:50:48     #15 0x55ef7504727e in _start /home/rdonnelly/mc/conda-bld/compilers_linux-64_1534865402226/work/.build/src/glibc-2.12.2/csu/../sysdeps/x86_64/elf/start.S:103 
Sep 08 15:50:48  
Sep 08 15:50:48 SUMMARY: UndefinedBehaviorSanitizer: undefined-behavior /var/lib/jenkins/workspace/aten/src/ATen/Utils.cpp:11:3 in  
Sep 08 15:50:48 + retcode=1 
Sep 08 15:50:48 + set -e 
Sep 08 15:50:48 + return 1 
Sep 08 15:50:48 + [[ pytorch-linux-xenial-py3-clang5-asan-test2 == *-NO_AVX-* ]] 
Sep 08 15:50:48 + [[ pytorch-linux-xenial-py3-clang5-asan-test2 == *-NO_AVX2-* ]] 
Sep 08 15:50:48 + '[' -n https://github.com/pytorch/pytorch/pull/43702 ']' 
Sep 08 15:50:48 ++ mktemp 
Sep 08 15:50:48 + DETERMINE_FROM=/tmp/tmp.rzDN7Xauqm 
Sep 08 15:50:48 + file_diff_from_base /tmp/tmp.rzDN7Xauqm 
Sep 08 15:50:48 + set +e 

See CircleCI build pytorch_linux_bionic_py3_8_gcc9_coverage_test (8/10)

Step: "Run tests" (full log | diagnosis details | 🔁 rerun)

Sep 08 16:45:43 ERROR:sccache::server: Compilation failed: Output { status: ExitStatus(ExitStatus(256)), stdout: "", stderr: "/var/lib/jenkins/.cache/torch_extensions/test_compilation_error_formatting/main.cpp: In function ‘int main()’:\n/var/lib/jenkins/.cache/torch_extensions/test_compilation_error_formatting/main.cpp:2:22: error: expected ‘;’ before ‘}’ token\n 2 | int main() { return 0 }\n | ^~\n | ;\n" }
Sep 08 16:45:43     raise RuntimeError(err_message) 
Sep 08 16:45:43 RuntimeError: test_torch failed! 
Sep 08 16:45:43  
Sep 08 16:45:43 real	40m58.358s 
Sep 08 16:45:43 user	46m0.887s 
Sep 08 16:45:43 sys	1m42.044s 
Sep 08 16:45:43 + cleanup 
Sep 08 16:45:43 + retcode=1 
Sep 08 16:45:43 + set +x 
Sep 08 16:45:43 =================== sccache compilation log =================== 
Sep 08 16:45:43 ERROR:sccache::server: Compilation failed: Output { status: ExitStatus(ExitStatus(256)), stdout: "", stderr: "/var/lib/jenkins/.cache/torch_extensions/test_compilation_error_formatting/main.cpp: In function ‘int main()’:\n/var/lib/jenkins/.cache/torch_extensions/test_compilation_error_formatting/main.cpp:2:22: error: expected ‘;’ before ‘}’ token\n    2 | int main() { return 0 }\n      |                      ^~\n      |                      ;\n" } 
Sep 08 16:45:43  
Sep 08 16:45:43 =========== If your build fails, please take a look at the log above for possible reasons =========== 
Sep 08 16:45:43 Compile requests                 0 
Sep 08 16:45:43 Compile requests executed        0 
Sep 08 16:45:43 Cache hits                       0 
Sep 08 16:45:43 Cache misses                     0 
Sep 08 16:45:43 Cache timeouts                   0 
Sep 08 16:45:43 Cache read errors                0 
Sep 08 16:45:43 Forced recaches                  0 
Sep 08 16:45:43 Cache write errors               0 

See CircleCI build pytorch_linux_xenial_py3_6_gcc5_4_build (9/10)

Step: "(Optional) Merge target branch" (full log | diagnosis details | 🔁 rerun)

Automatic merge failed; fix conflicts and then commit the result.
CONFLICT (add/add): Merge conflict in .circleci/scripts/binary_linux_build.sh 
Auto-merging .circleci/scripts/binary_linux_build.sh 
CONFLICT (add/add): Merge conflict in .circleci/config.yml 
Auto-merging .circleci/config.yml 
CONFLICT (add/add): Merge conflict in .circleci/cimodel/data/pytorch_build_definitions.py 
Auto-merging .circleci/cimodel/data/pytorch_build_definitions.py 
CONFLICT (add/add): Merge conflict in .circleci/cimodel/data/pytorch_build_data.py 
Auto-merging .circleci/cimodel/data/pytorch_build_data.py 
CONFLICT (add/add): Merge conflict in .circleci/cimodel/data/binary_build_data.py 
Auto-merging .circleci/cimodel/data/binary_build_data.py 
Automatic merge failed; fix conflicts and then commit the result. 

See CircleCI build pytorch_xla_linux_bionic_py3_6_clang9_build (10/10)

Step: "(Optional) Merge target branch" (full log | diagnosis details | 🔁 rerun)

Automatic merge failed; fix conflicts and then commit the result.
CONFLICT (add/add): Merge conflict in .circleci/scripts/binary_linux_build.sh 
Auto-merging .circleci/scripts/binary_linux_build.sh 
CONFLICT (add/add): Merge conflict in .circleci/config.yml 
Auto-merging .circleci/config.yml 
CONFLICT (add/add): Merge conflict in .circleci/cimodel/data/pytorch_build_definitions.py 
Auto-merging .circleci/cimodel/data/pytorch_build_definitions.py 
CONFLICT (add/add): Merge conflict in .circleci/cimodel/data/pytorch_build_data.py 
Auto-merging .circleci/cimodel/data/pytorch_build_data.py 
CONFLICT (add/add): Merge conflict in .circleci/cimodel/data/binary_build_data.py 
Auto-merging .circleci/cimodel/data/binary_build_data.py 
Automatic merge failed; fix conflicts and then commit the result. 

❄️ 1 failure tentatively classified as flaky

but reruns have not yet been triggered to confirm:

See CircleCI build pytorch_linux_xenial_cuda10_2_cudnn7_py3_gcc7_test (1/1)

Step: "Run tests" (full log | diagnosis details | 🔁 rerun) ❄️

Sep 08 18:20:20 ConnectionResetError: [Errno 104] Connection reset by peer
Sep 08 18:20:20   File "/opt/conda/lib/python3.6/multiprocessing/connection.py", line 455, in accept 
Sep 08 18:20:20     deliver_challenge(c, self._authkey) 
Sep 08 18:20:20   File "/opt/conda/lib/python3.6/multiprocessing/connection.py", line 722, in deliver_challenge 
Sep 08 18:20:20     response = connection.recv_bytes(256)        # reject large message 
Sep 08 18:20:20   File "/opt/conda/lib/python3.6/multiprocessing/connection.py", line 216, in recv_bytes 
Sep 08 18:20:20     buf = self._recv_bytes(maxlength) 
Sep 08 18:20:20   File "/opt/conda/lib/python3.6/multiprocessing/connection.py", line 407, in _recv_bytes 
Sep 08 18:20:20     buf = self._recv(4) 
Sep 08 18:20:20   File "/opt/conda/lib/python3.6/multiprocessing/connection.py", line 379, in _recv 
Sep 08 18:20:20     chunk = read(handle, remaining) 
Sep 08 18:20:20 ConnectionResetError: [Errno 104] Connection reset by peer 
Sep 08 18:20:20 /opt/conda/lib/python3.6/multiprocessing/semaphore_tracker.py:143: UserWarning: semaphore_tracker: There appear to be 14 leaked semaphores to clean up at shutdown 
Sep 08 18:20:20   len(cache)) 
Sep 08 18:20:22 Process ErrorTrackingProcess-156: 
Sep 08 18:20:22 Traceback (most recent call last): 
Sep 08 18:20:22   File "/opt/conda/lib/python3.6/multiprocessing/process.py", line 258, in _bootstrap 
Sep 08 18:20:22     self.run() 
Sep 08 18:20:22   File "/var/lib/jenkins/workspace/test/test_dataloader.py", line 361, in run 
Sep 08 18:20:22     super(ErrorTrackingProcess, self).run() 
Sep 08 18:20:22   File "/opt/conda/lib/python3.6/multiprocessing/process.py", line 93, in run 
Sep 08 18:20:22     self._target(*self._args, **self._kwargs) 

ci.pytorch.org: 1 failed


This comment was automatically generated by Dr. CI (expand for details). Follow this link to opt-out of these comments for your Pull Requests.

Please report bugs/suggestions on the GitHub issue tracker or post in the (internal) Dr. CI Users group.

See how this bot performed.

This comment has been revised 45 times.



/**
* Whether or not the imaginary part of the tensor should be conjugated
Contributor

Technically, you only conjugate a complex number; conjugation involves negating the imaginary part of the tensor.

void set_conjugate(bool value) {
conjugate_ = value;
if (conjugate_)
key_set_ = key_set_.add(DispatchKey::Named);
Contributor

Named? ;)
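
Presumably the intended key is the Conjugate key this PR introduces; a corrected sketch of the quoted setter (also clearing the key when the bit is unset) might look like:

void set_conjugate(bool value) {
  conjugate_ = value;
  if (conjugate_) {
    // Mark the tensor so dispatch routes through the conjugate fallback.
    key_set_ = key_set_.add(DispatchKey::Conjugate);
  } else {
    key_set_ = key_set_.remove(DispatchKey::Conjugate);
  }
}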


// This fallback effectively takes all tensors in the stack
// with their conjugate bit set, and runs conjugation on them
void conjugateFallback(const c10::OperatorHandle& op, torch::jit::Stack* stack);
Contributor

add a CAFFE2_API to the beginning of this function; that will fix your linker error. (We hide symbols by default to make our behavior like Windows, which means you need to explicitly annotate functions that will be accessible from other files.)

However, even better would be to directly register the fallback in ConjugateFallback.cpp using a TORCH_LIBRARY_IMPL block, rather than in the test; after all, you want the registration to be available for all code!
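
A sketch of what that direct registration could look like in ConjugateFallback.cpp (assuming the Conjugate dispatch key added in this PR; with the registration living next to the definition, the symbol no longer needs to be exported at all):

// ConjugateFallback.cpp (sketch)
#include <torch/library.h>
#include <ATen/core/dispatch/Dispatcher.h>
#include <ATen/core/stack.h>

namespace at {
void conjugateFallback(const c10::OperatorHandle& op, torch::jit::Stack* stack) {
  // ... materialize conjugated tensor arguments, then redispatch ...
}
} // namespace at

// Register the boxed fallback for every operator dispatched on the Conjugate
// key, so it is available everywhere rather than only inside the test.
TORCH_LIBRARY_IMPL(_, Conjugate, m) {
  m.fallback(torch::CppFunction::makeFromBoxedFunction<&at::conjugateFallback>());
}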

}

TEST(BackendFallbackTest, ConjugateTest) {
auto m = MAKE_TORCH_LIBRARY_IMPL(_, TESTING_ONLY_GenericMode);
Contributor

You're registering to the wrong key; it should be Conjugate.

auto m = MAKE_TORCH_LIBRARY_IMPL(_, TESTING_ONLY_GenericMode);
m.fallback(torch::CppFunction::makeFromBoxedFunction<&conjugateFallback>());

c10::impl::IncludeDispatchKeyGuard guard(DispatchKey::Conjugate);
Contributor

You don't need this. The Conjugate key is on the tensor, so that's how you will get to the conjugate implementation. Check the dispatcher slides again about this :)

bdhirsh added a commit that referenced this pull request Aug 29, 2020
…_LIBRARY_IMPL to register my conjugate fallback (plus other minor fixes)

I'm still hitting issues - tracing through my failing test, I see two problems:
- my call to 'a.conj_view() + b' fails with "There were no tensor arguments to this function...", which sounds vaguely to me like I corrupted the call stack in some way in my conjugateFallback function call
- tracing through a call to conjugateFallback(), I see that none of the IValue arguments on the stack are actually tensors (ivalue.isTensor() == false), so the function just pops off all of the arguments and pushes them back on.

first cut of adding backend fallback for conjugation

ghstack-source-id: 4fc04d5
Pull Request resolved: #43702
void conjugateFallback(const c10::OperatorHandle& op, torch::jit::Stack* stack) {
// Unwrap all arguments
const auto num_arguments = op.schema().arguments().size();
const auto arguments = torch::jit::last(stack, num_arguments);
Contributor

torch::jit::last returns a non-owning reference to the arguments. When you pop them later on line 14, this reference is no longer valid; hence your memory corruption.
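
One way to avoid the dangling reference is to take an owning copy of the arguments before dropping them from the stack. A minimal sketch (it assumes the conj_materialize helper from this PR, and uses Tensor::is_conj() as a stand-in for the prototype's conjugate-bit getter):

#include <ATen/core/dispatch/Dispatcher.h>
#include <ATen/core/stack.h>
#include <utility>
#include <vector>

at::Tensor conj_materialize(const at::Tensor& self);  // defined elsewhere in this PR (assumption)

void conjugateFallback(const c10::OperatorHandle& op, torch::jit::Stack* stack) {
  const auto num_arguments = op.schema().arguments().size();

  // last() returns a non-owning view; .vec() copies it into a vector we own,
  // so dropping the originals below cannot invalidate what we iterate over.
  std::vector<c10::IValue> arguments = torch::jit::last(stack, num_arguments).vec();
  torch::jit::drop(stack, num_arguments);

  for (auto& ivalue : arguments) {
    if (ivalue.isTensor() && ivalue.toTensor().is_conj()) {
      torch::jit::push(stack, conj_materialize(ivalue.toTensor()));
    } else {
      torch::jit::push(stack, std::move(ivalue));
    }
  }
  // ... redispatch to the actual kernel ...
}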

bdhirsh added a commit that referenced this pull request Aug 31, 2020
Fixed the previous issue and generally cleaned up comments/code. I had to define conj_materialize() inside the ConjugateFallback file. Before that I tried adding it directly to native_functions.yaml (which we probably don't want anyway), which caused the dispatcher to loop infinitely (repeatedly alternating between conj_materialize() and conjugateFallback()).

Responding to the first set of PR feedback. Main change = using TORCH_LIBRARY_IMPL to register my conjugate fallback (plus other minor fixes)
I'm still hitting issues - tracing through my failing test, I see two problems:
- my call to 'a.conj_view() + b' fails with "There were no tensor arguments to this function...", which sounds vaguely to me like I corrupted the call stack in some way in my conjugateFallback function call
- tracing through a call to conjugateFallback(), I see that none of the IValue arguments on the stack are actually tensors (ivalue.isTensor() == false), so the function just pops off all of the arguments and pushes them back on.

first cut of adding backend fallback for conjugation

ghstack-source-id: 0f01918
Pull Request resolved: #43702
@ezyang
Contributor

ezyang commented Sep 1, 2020

fyi ci failures are real

..\aten\src\ATen\core\dispatch\backend_fallback_test.cpp(87): error C2882: 'torch': illegal use of namespace identifier in expression
..\aten\src\ATen\core\dispatch\backend_fallback_test.cpp(88): error C2882: 'torch': illegal use of namespace identifier in expression


Tensor conj_materialize(const Tensor& self) {
// this function assumes that the tensor input has its conjugate bit set
self.set_conjugate(false);
Contributor

This is not actually sound because it will affect all other references to the tensor; you need to allocate a new tensor (sharing storage/sizes/strides) and then set the conjugate bit to false and then pass it to conj. Or maybe we come up with an alternate way to call self.conj() that bypasses the conjugate flag and interprets the tensor in the "original" way (similar to how AutoNonVariableTypeMode works).
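
A sketch of that second alternative - a guard that masks out the Conjugate key, in the spirit of AutoNonVariableTypeMode, and roughly the exclude-guard approach a later revision of this PR tries. It assumes the Conjugate dispatch key from this PR and that conj() is still the out-of-place (materializing) op, as in this prototype:

#include <ATen/ATen.h>
#include <c10/core/impl/LocalDispatchKeySet.h>

at::Tensor conj_materialize(const at::Tensor& self) {
  // Mask out the Conjugate key for the duration of this call, so the conj()
  // below reaches the ordinary out-of-place kernel instead of bouncing back
  // into the fallback; other references to `self` are left untouched.
  c10::impl::ExcludeDispatchKeyGuard guard(c10::DispatchKey::Conjugate);
  return self.conj();
}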

Contributor

One reason I'm a little hesitant about a NoConjMode is that it makes the invariants on the system a little more difficult to understand. With your current system, it's easy to see how conjugate bits relate to the Conjugate dispatch key: before running the Conjugate dispatch key, you may have tensors with conjugate bit set; after the dispatch key, nothing has conjugate bit set. With NoConjMode, you may selectively "let tensors with conjugate bit" through to the backend, and so invariants get a little harder to see.

That being said, it's a little hard for me to understand how to structure things in a reasonable way without letting the conjugate bit pass through to the backend. For example, matrix multiply with fused conjugation should definitely be implemented in the CPU/CUDA backend, so we definitely have to pass the conjugate bit all the way to matmul. So perhaps the invariant isn't worth preserving.

Contributor Author

It sounds like you're saying that we can lump backend kernels into two groups:

  • conjugation-aware: a (relatively small number of) backend/op pairs that can take in a conjugate flag and do conjugation along with the backend operation (e.g. multiplication on CPU/CUDA) in a performant way (something we want to take advantage of). We DO want to pass the conjugate bit through to the backend in this case.
  • conjugation-unaware: kernels that require conjugation to be performed separately (and part of the question is whether we want to explicitly do the conjugation in the kernel, or move the logic somewhere else like I'm doing in this PR). It would be (maybe) preferable to not pass the conjugate bit through to the backend, to make the system easier to reason about.

Would it be reasonable to make the dispatcher aware of the specific op/backend pairs that can take in a conjugation flag, and know to only call out to my conjugate fallback if (a) the conjugate key is set on a tensor input, and (b) the particular op/backend pair currently being called is not in that first category?

I guess that would make this conjugateFallback kernel not really a "fallback" anymore. More like a fallback "for everything except this whitelist of op/backend pairs", which also might become pretty hard to reason about.

Contributor

It would be (maybe) preferable to not pass the conjugate bit through to the backend, to make the system easier to reason about.

Yes. The only trouble is if you "accidentally" pass along a tensor with the conjugate bit to the backend, and if there is no sanity checking in the backend, you will end up with hilarious bugs where the result is the conjugate of the expected result.

Would it be reasonable to make the dispatcher aware of the specific op/backend pairs that can take in a conjugation flag, and know to only call out to my conjugate fallback if (a) the conjugate key is set on a tensor input, and (b) the particular op/backend pair currently being called is not in that first category?

Yes. In fact, we can do this by registering fallthrough functions for operations that support the conjugation bit directly. It is probably not worth splitting this into a per-backend concept; it is probably better as an all-or-nothing deal.
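
For example, an operator whose backend kernels already understand the conjugate bit could opt out of the fallback like this (a sketch; mm is just an illustrative choice of op, and it assumes the Conjugate key from this PR):

#include <torch/library.h>

// A fallthrough for the Conjugate key makes dispatch skip the fallback and go
// straight to the backend kernel, which then sees the conjugate bit itself.
TORCH_LIBRARY_IMPL(aten, Conjugate, m) {
  m.impl("mm", torch::CppFunction::makeFallthrough());
}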

More like a fallback "for everything except this whitelist of op/backend pairs", which also might become pretty hard to reason about.

I agree. Which is why @zdevito's suggestion to just call materialize_conj manually in all the kernels is appealing. After all, you already had to do some coding to make complex work; this is just one more thing.

@ezyang
Contributor

ezyang commented Sep 1, 2020

This prototype looks great. We'll talk about what we should do next with this in our sync shortly.

cc @anjali411

auto conjugated_tensor = conj_materialize(tensor);
torch::jit::push(stack, conjugated_tensor);
} else {
torch::jit::push(stack, ivalue);
Contributor

nit: You can improve the performance of this code by being more careful about moving ivalues on and off the argument stack. This will help you avoid unnecessary reference count bumps. So you would move an argument off the arguments list, and then move it back onto the stack when you push it.

Contributor Author

I'm going to try to fix this: instead of popping all arguments off the stack and pushing them all back on (most of which are unchanged except for the conjugated tensor), I'm just going to selectively overwrite the memory in the stack (vector) for the tensors that need to be conjugated. Let me know if that's what you had in mind.
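
A sketch of that in-place rewrite (same assumptions as the earlier fallback sketch: the conj_materialize helper from this PR, and Tensor::is_conj() standing in for the prototype's conjugate-bit getter):

#include <ATen/core/dispatch/Dispatcher.h>
#include <ATen/core/stack.h>

at::Tensor conj_materialize(const at::Tensor& self);  // defined elsewhere in this PR (assumption)

void conjugateFallback(const c10::OperatorHandle& op, torch::jit::Stack* stack) {
  const auto num_arguments = op.schema().arguments().size();
  const auto arguments_begin = stack->size() - num_arguments;

  // Overwrite only the entries that need it; everything else stays put, so
  // there is no pop/push round trip and no extra refcount traffic.
  for (size_t i = 0; i < num_arguments; ++i) {
    auto& ivalue = (*stack)[arguments_begin + i];
    if (ivalue.isTensor() && ivalue.toTensor().is_conj()) {
      ivalue = c10::IValue(conj_materialize(ivalue.toTensor()));
    }
  }
  // ... redispatch with the rewritten arguments still on the stack ...
}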

Contributor

That will be even better :)

use_c10_dispatcher: full
variants: function, method

- func: conj_view(Tensor self) -> Tensor
Contributor

@ezyang ezyang Sep 1, 2020

@bdhirsh asked why we named this separately as conj_view. Based on discussion with @anjali411 and co at #43270 we are seriously considering making this implementation conj (and so the actual kernel that converts the conjugation out-of-place will need to be given a different name)


// Whether or not to conjugate the imaginary part of the tensor
// TODO: is there a reason that we need this bool?
// Doesn't the existence of the conjugate dispatch key give us everything we need?
Contributor

From in person discussion: No, it's probably fine to omit this boolean here since it's in the dispatch key set.

@ezyang
Contributor

ezyang commented Sep 1, 2020

Some notes about what else we'd need to do to land this:

  1. Make autograd work on conj_view (it's a view operation, so it needs to be treated specially in the same way other view operations are treated)
  2. Actually get some payoff by fusing conjugation with matrix multiply (and work out the idiom for how to bypass the fallback to the underlying layer)
  3. Seriously consider @zdevito's suggestion to NOT do this with a backend fallback, and instead insert explicit materializations in kernels that support complex numbers (this would be a bit more legwork)
  4. Bikeshed names (like conj_materialize)
  5. Make sure inplace operations on conjugate views do something reasonable, or make sure these raise an error saying that it is not supported without an explicit materialize.
  6. A more detailed design as we discover other edge cases to deal with

bdhirsh added a commit that referenced this pull request Sep 8, 2020
Fixed the previous issue and generally cleaned up comments/code. I had to define conj_materialize() inside the ConjugateFallback file. Before that I tried adding it directly to native_functions.yaml (which we probably don't want anyway), which caused the dispatcher to loop infinitely (repeatedly alternating between conj_materialize() and conjugateFallback()).

Responding to the first set of PR feedback. Main change = using TORCH_LIBRARY_IMPL to register my conjugate fallback (plus other minor fixes)
I'm still hitting issues - tracing through my failing test, I see two problems:
- my call to 'a.conj_view() + b' fails with "There were no tensor arguments to this function...", which sounds vaguely to me like I corrupted the call stack in some way in my conjugateFallback function call
- tracing through a call to conjugateFallback(), I see that none of the IValue arguments on the stack are actually tensors (ivalue.isTensor() == false), so the function just pops off all of the arguments and pushes them back on.

first cut of adding backend fallback for conjugation

ghstack-source-id: 162a424
Pull Request resolved: #43702

cr feedback: (a) try an exclude guard instead of removing the conjugate flag directly on the tensor, (b) make the conjugate fallback more efficient, (c) remove the unnecessary bool conjugate_, plus some other cleanup
@ezyang
Contributor

ezyang commented Sep 16, 2020

Some more updates having worked on this on Tuesday:

  • Integration internally is pretty involved, so I think zdevito definitely was right and we shouldn't use the backend fallback for core (might be possible to use for custom kernels, but chances are they just won't support complex at all)
  • Need to update TensorIterator to directly understand conj materialization, it's a bit like dtype promotion
  • copy_ is the most important function; it effectively is how you do conjugation in the new world order
  • There is a bunch of blas stuff we aren't covering properly that needs to be fixed. E.g. cblas_gemv is not being used for gemv on complex on CPU #44741

anjali411 added a commit that referenced this pull request Apr 23, 2021
Based off of ezyang's (#44799) and bdhirsh's (#43702) prototypes:

TODO:
1. docs
2. handle tensorlist

[ghstack-poisoned]
anjali411 added a commit that referenced this pull request Apr 27, 2021
Based off of ezyang's (#44799) and bdhirsh's (#43702) prototypes:

TODO:
1. conjugate view RFC
2. Enable conj view testing for all OpInfos

[ghstack-poisoned]
anjali411 added a commit that referenced this pull request Apr 29, 2021
Based off of ezyang's (#44799) and bdhirsh's (#43702) prototypes:

Here's a summary of the changes in this PR:
This PR adds a new dispatch key called Conjugate. This enables us to make conjugate operation a view and leverage the specialized library functions that fast path with the hermitian operation (conj + transpose). 

1. Conjugate operation will now return a view with conj bit (1) for complex tensors and returns self for non-complex tensors as before. This also means `torch.view_as_real` will no longer be a view on conjugated complex tensors and is hence disabled. To fill the gap, we have added `torch.view_as_real_physical` which would return the real tensor agnostic of the conjugate bit on the input complex tensor. The information about conjugation on the old tensor can be obtained by calling `.is_conj()` on the new tensor.
2. NEW API: 
    a) `.conj()` -- now returning a view.
    b) `.conj_physical()` -- does the physical conjugate operation. If the conj bit for input was set, you'd get `self.clone()`, else you'll get a new tensor with conjugated value in its memory.
    c) `.conj_physical_()`, and `out=` variant
    d) `.resolve_conj()`  -- materializes the conjugation. returns self if the conj bit is unset, else returns a new tensor with conjugated values and conj bit set to 0.
    e) `.resolve_conj_()` in-place version of (d)
    f) `view_as_real_physical` -- as described in (1), it's functionally same as `view_as_real`, just that it doesn't error out on conjugated tensors.
    g) `view_as_real` -- existing function, but now errors out on conjugated tensors.
3. Conjugate Fallback
    a) Vast majority of PyTorch functions would currently use this fallback when they are called on a conjugated tensor.
    b) This fallback is well equipped to handle the following cases:
        - functional operation e.g., `torch.sin(input)`
        - Mutable inputs and in-place operations e.g., `tensor.add_(2)`
        - out-of-place operation e.g., `torch.sin(input, out=out)`
        - Tensorlist input args
        - NOTE: Meta tensors don't work with conjugate fallback.
4. Autograd
    a) `resolve_conj()` is an identity function w.r.t. autograd
    b) Everything else works as expected.
5. Testing: 
    a) All method_tests run with conjugate view tensors.
    b) OpInfo tests that run with conjugate views
        - test_variant_consistency_eager/jit
        - gradcheck, gradgradcheck
        - test_conj_views (that only run for `torch.cfloat` dtype)

TODO:
1. conjugate view RFC


[ghstack-poisoned]
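
A minimal C++ usage sketch of the API summarized above (assuming the names land as described):

#include <ATen/ATen.h>
#include <iostream>

int main() {
  at::Tensor x = at::randn({3}, at::kComplexFloat);

  at::Tensor y = x.conj();           // lazy: a view with the conj bit set
  std::cout << y.is_conj() << "\n";  // 1

  at::Tensor z = y.resolve_conj();   // materializes the conjugation
  std::cout << z.is_conj() << "\n";  // 0

  at::Tensor w = x.conj_physical();  // eager conjugation; no conj bit involved
  std::cout << w.is_conj() << "\n";  // 0
  return 0;
}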
anjali411 added a commit that referenced this pull request Apr 29, 2021
Based off of ezyang's (#44799) and bdhirsh's (#43702) prototypes:

Here's a summary of the changes in this PR:
This PR adds a new dispatch key called Conjugate. This enables us to make conjugate operation a view and leverage the specialized library functions that fast path with the hermitian operation (conj + transpose). 

1. Conjugate operation will now return a view with conj bit (1) for complex tensors and returns self for non-complex tensors as before. This also means `torch.view_as_real` will no longer be a view on conjugated complex tensors and is hence disabled. To fill the gap, we have added `torch.view_as_real_physical` which would return the real tensor agnostic of the conjugate bit on the input complex tensor. The information about conjugation on the old tensor can be obtained by calling `.is_conj()` on the new tensor.
2. NEW API: 
    a) `.conj()` -- now returning a view.
    b) `.conj_physical()` -- does the physical conjugate operation. If the conj bit for input was set, you'd get `self.clone()`, else you'll get a new tensor with conjugated value in its memory.
    c) `.conj_physical_()`, and `out=` variant
    d) `.resolve_conj()`  -- materializes the conjugation. returns self if the conj bit is unset, else returns a new tensor with conjugated values and conj bit set to 0.
    e) `.resolve_conj_()` in-place version of (d)
    f) `view_as_real_physical` -- as described in (1), it's functionally same as `view_as_real`, just that it doesn't error out on conjugated tensors.
    g) `view_as_real` -- existing function, but now errors out on conjugated tensors.
3. Conjugate Fallback
    a) Vast majority of PyTorch functions would currently use this fallback when they are called on a conjugated tensor.
    b) This fallback is well equipped to handle the following cases:
        - functional operation e.g., `torch.sin(input)`
        - Mutable inputs and in-place operations e.g., `tensor.add_(2)`
        - out-of-place operation e.g., `torch.sin(input, out=out)`
        - Tensorlist input args
        - NOTE: Meta tensors don't work with conjugate fallback.
4. Autograd
    a) `resolve_conj()` is an identity function w.r.t. autograd
    b) Everything else works as expected.
5. Testing: 
    a) All method_tests run with conjugate view tensors.
    b) OpInfo tests that run with conjugate views
        - test_variant_consistency_eager/jit
        - gradcheck, gradgradcheck
        - test_conj_views (that only run for `torch.cfloat` dtype)
 
NOTE: functions like `empty_like`, `zero_like`, `randn_like`, `clone` don't propagate the conjugate bit.

Follow up work:
1. conjugate view RFC
2. Add neg bit to re-enable view operation on conjugated tensors
3. Update linalg functions to call into specialized functions that fast path with the hermitian operation. 


[ghstack-poisoned]
anjali411 added a commit that referenced this pull request Apr 30, 2021
anjali411 added a commit that referenced this pull request Apr 30, 2021
anjali411 added a commit that referenced this pull request May 5, 2021
anjali411 added a commit that referenced this pull request May 5, 2021
anjali411 added a commit that referenced this pull request May 5, 2021
anjali411 added a commit that referenced this pull request May 10, 2021
anjali411 added a commit that referenced this pull request May 11, 2021
anjali411 added a commit that referenced this pull request May 11, 2021
anjali411 added a commit that referenced this pull request Jun 3, 2021
anjali411 added a commit that referenced this pull request Jun 3, 2021
anjali411 added a commit that referenced this pull request Jun 3, 2021
anjali411 added a commit that referenced this pull request Jun 4, 2021
anjali411 added a commit that referenced this pull request Jun 4, 2021
facebook-github-bot pushed a commit that referenced this pull request Jun 4, 2021
Summary:
Pull Request resolved: #54987

Based off of ezyang (#44799) and bdhirsh (#43702)'s prototype.

Test Plan: Imported from OSS

Reviewed By: VitalyFedyunin

Differential Revision: D28227315

Pulled By: anjali411

fbshipit-source-id: acab9402b9d6a970c6d512809b627a290c8def5f
deniskokarev pushed a commit to deniskokarev/pytorch that referenced this pull request Jun 9, 2021