first cut of adding backend fallback for conjugation #43702
Conversation
aten/src/ATen/ConjugateFallback.cpp
Outdated
// Is it reasonable to set the memory on the stack in place like this?
// I don't see other examples of fallbacks doing this,
// so instead I'm pushing the new conjugated vector onto the stack
auto conjugated_tensor = tensor.conj();
See the comment here - I don't think I have a clear understanding of how a given fallback in the dispatch workflow is supposed to modify the stack.
My understanding is that the stack represents the call stack, i.e. the top of it contains the arguments to the kernel function that will eventually get called after all other dispatch fallback functions.
Should my conjugate fallback be modifying those arguments in place? Pushing new arguments onto the stack?
The stack is a mutable data structure, and the invariants are that when you call some downstream function, the top N arguments of the stack (however many arguments the function takes) are the arguments. On return, the invariant is that the input arguments had been popped from the stack, and the outputs are pushed onto the stack.
before stack: [arg1, arg2, arg3] # NB: maybe it's the other way, I can never remember
call function
after stack: [ret1, ret2]
As the stack is mutable, an easy way you can modify the stack is simply by overwriting the direct entries with the new, explicitly conjugated (not using the bit) version. I'm not sure if the stack API has a direct way of expressing this, but stacks are just vectors so you can implement it by hand yourself.
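To make that concrete, here is a minimal sketch (not this PR's actual code) of a boxed fallback that follows the pop/push invariant; `conj_materialize` is the helper discussed further down, and the exact name of the conjugate-bit getter is an assumption:

```cpp
#include <ATen/ATen.h>
#include <ATen/core/stack.h>
#include <torch/library.h>

// Hypothetical helper (discussed below): returns the physically conjugated
// tensor with the conjugate bit cleared.
at::Tensor conj_materialize(const at::Tensor& self);

void conjugateFallback(const c10::OperatorHandle& op, torch::jit::Stack* stack) {
  const auto num_arguments = op.schema().arguments().size();

  // Take an owning copy of the top N IValues, then pop them off the stack.
  std::vector<c10::IValue> args(stack->end() - num_arguments, stack->end());
  torch::jit::drop(*stack, num_arguments);

  // Push the arguments back, replacing any conjugate-bit tensor with its
  // explicitly conjugated version so downstream kernels see plain tensors.
  for (auto& ivalue : args) {
    if (ivalue.isTensor() && ivalue.toTensor().is_conj()) {  // getter name assumed
      torch::jit::push(*stack, conj_materialize(ivalue.toTensor()));
    } else {
      torch::jit::push(*stack, std::move(ivalue));
    }
  }

  // Redispatch: the rewritten tensors no longer carry the Conjugate key,
  // so this call lands in the next (backend) kernel.
  op.callBoxed(stack);
}
```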
Yeah, I looked over your implementation and it doesn't work. I can give you more hints about how to fix it, but perhaps the invariant described above is enough?
BTW, the use of conj() here directly is also unlikely to work, because you still have the conjugate bit on the tensor and you'll redispatch here. You need an operation like conj_materialize() which will take a tensor whose conjugate bit is set, and return the conjugated version without the conjugate bit set. You could probably do this by first allocating a new tensor a la conj_view with the conjugate bit turned off, and then calling conj().
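One possible sketch of such a helper, using the dispatch-key exclude guard that comes up later in this thread rather than building an explicit alias (treat this as an illustration under those assumptions, not the PR's code):

```cpp
#include <ATen/ATen.h>
#include <c10/core/impl/LocalDispatchKeySet.h>

// Sketch: physically conjugate a tensor whose conjugate bit is set, returning
// a new tensor with the bit cleared, without mutating the input.
at::Tensor conj_materialize(const at::Tensor& self) {
  // Temporarily exclude the Conjugate key so that calling conj() here does
  // not redispatch back into the conjugate fallback.
  c10::impl::ExcludeDispatchKeyGuard no_conj_guard(c10::DispatchKey::Conjugate);
  // With the key excluded, this goes straight to the backend conj kernel,
  // which allocates a fresh output tensor (no conjugate bit) holding the
  // negated imaginary values.
  return self.conj();
}
```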
thanks - I fixed the issues described (I think), but I'm still hitting problems. It might help to lay out my understanding of the stack in relation to a dispatch chain:
A dispatch "chain" is a sequence of function calls, consisting of (possibly several) "pre-processing" functions followed by a single backend kernel function. The functions are called in order of dispatch key priority, and the set of functions is determined by:
(1) the set of dispatch keys in the current context (in the tensor arguments, and the local/global context)
(2) the mappings from a given op/dispatch key to kernel functions or "fallback" functions. (In the case of this conjugate fallback, we want the function to be called for all tensors with the Conjugate dispatch key set, regardless of the op.)
A reasonable dispatch chain in the case of my change: Conjugate -> Add(cpu).
Now, I'm currently thinking of the stack as being passed along to each of the functions in the dispatch chain, such that the state of the stack at the end of the first function is the same as the state at the beginning of the second function.
E.g. suppose we have:
Tensor a = ...
Tensor b = ...
Tensor c = a.conj_view() + b
Invoking the "add" function should eventually invoke the dispatcher, first calling conjugateFallback() and then the add() kernel. I'd expect the stack at the beginning of the call to conjugateFallback to look like:
[a (tensor, conj bit set); b (tensor, conj bit not set)]
and after the call to look like:
[a (conj bit NOT set, but the imaginary part negated); b (unchanged, conj bit not set)]
(What we'd want the stack to look like at the beginning of the call to the add() kernel)
Which I believe is what the above is now doing. I'm currently hitting an issue, however, where the arguments at the top of the stack are not being recognized as tensors (their tag is set to c10::IValue::Tag::None). Later, the call to the add() kernel fails with "There were no tensor arguments to this function (e.g., you passed an empty list of Tensors)", which makes me think I corrupted the stack somewhere in my conjugate fallback.
This description looks accurate.
TEST(BackendFallbackTest, ConjugateTest) {
  auto m = MAKE_TORCH_LIBRARY_IMPL(_, TESTING_ONLY_GenericMode);
  m.fallback(torch::CppFunction::makeFromBoxedFunction<&conjugateFallback>());
Couldn't get this line to compile :(
/home/hirsheybar/pytorch/build/../aten/src/ATen/core/boxing/KernelFunction_impl.h:22: undefined reference to `at::conjugateFallback(c10::OperatorHandle const&, std::vector<c10::IValue, std::allocator<c10::IValue> >*)'
collect2: error: ld returned 1 exit status
conjugateFallback() is defined in ConjugateFallback.h, which is included at the top of this file. It's in the at:: namespace, which we're using here as well. I also tried replacing torch::jit::Stack* in the function header with std::vector<c10::IValue, std::allocator<c10::IValue> >* as the error suggests, which didn't fix it (I'm assuming Stack is aliased to that type).
aten/src/ATen/native/UnaryOps.cpp
Outdated
Tensor conj_view(const Tensor& self) {
  Tensor self_;
  if (self.is_quantized()) {
    auto impl = c10::make_intrusive<QTensorImpl>(
I pulled the quantized logic from alias_with_sizes_and_strides, but I'm not actually sure if I need it.
We don't have quantized complex, so probably not :)
aten/src/ATen/templates/TensorBody.h
Outdated
  return !at::impl::variable_excluded_from_dispatch();
}

// TODO: do I need to add this method directly on the Tensor to implement conj_view() in UnaryOps?
I added a getter/setter for conjugate() on the TensorImpl class, but I needed to add the getter to the Tensor class as well in order to flip the state correctly in the conj_view() function, since it deals with Tensors and not TensorImpls. This doesn't feel right though- let me know if I'm missing something.
This is OK for a prototype. In the real implementation we'll probably have to bikeshed under what name we should expose this functionality, and probably it should get a native functions schema.
💊 CI failures summary and remediations
As of commit d31d4fe (more details on the Dr. CI page):
🕵️ 10 new failures recognized by patterns
The following CI failures do not appear to be due to upstream breakages:
c10/core/TensorImpl.h
Outdated
/**
 * Whether or not the imaginary part of the tensor should be conjugated
Technically, you only conjugate a complex number; conjugation involves negating the imaginary part of the tensor.
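For reference: conj(a + b·i) = a − b·i; only the imaginary part changes sign.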
c10/core/TensorImpl.h
Outdated
void set_conjugate(bool value) {
  conjugate_ = value;
  if (conjugate_)
    key_set_ = key_set_.add(DispatchKey::Named);
Named? ;)
// This fallback effectively takes all tensors in the stack
// with their conjugate bit set, and runs conjugation on them
void conjugateFallback(const c10::OperatorHandle& op, torch::jit::Stack* stack);
Add a CAFFE2_API to the beginning of this function; that will fix your linker error. (We hide symbols by default to make our behavior match Windows, and that means you need to explicitly annotate functions which will be accessible from other files.)
However, even better would be to directly register the fallback in ConjugateFallback.cpp using a TORCH_LIBRARY_IMPL block, rather than in the test; after all, you want the registration to be available for all code!
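A sketch of what that registration could look like in ConjugateFallback.cpp (key and function names follow this PR; treat the exact form as an assumption):

```cpp
#include <ATen/ATen.h>
#include <torch/library.h>

// Boxed fallback defined elsewhere in this file.
void conjugateFallback(const c10::OperatorHandle& op, torch::jit::Stack* stack);

// Register the fallback for every operator under the Conjugate key, so any
// op that sees a tensor with the conjugate bit set routes through it.
TORCH_LIBRARY_IMPL(_, Conjugate, m) {
  m.fallback(torch::CppFunction::makeFromBoxedFunction<&conjugateFallback>());
}
```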
}

TEST(BackendFallbackTest, ConjugateTest) {
  auto m = MAKE_TORCH_LIBRARY_IMPL(_, TESTING_ONLY_GenericMode);
You're registering to the wrong key, should be Conjugate
auto m = MAKE_TORCH_LIBRARY_IMPL(_, TESTING_ONLY_GenericMode);
m.fallback(torch::CppFunction::makeFromBoxedFunction<&conjugateFallback>());

c10::impl::IncludeDispatchKeyGuard guard(DispatchKey::Conjugate);
You don't need this. Conjugate key is on the tensor, so that's how you will get to the conjugate implementation. Check the dispatcher slides again about this :)
…_LIBRARY_IMPL to register my conjugate fallback (plus other minor fixes).
I'm still hitting issues - tracing through my failing test, I see two problems:
- my call to 'a.conj_view() + b' fails with "There were no tensor arguments to this function...", which sounds vaguely to me like I corrupted the call stack in some way in my conjugateFallback function call
- tracing through a call to conjugateFallback(), I see that none of the ivalue arguments on the stack are actually tensors (ivalue.isTensor() == false), so the function just pops off all of the arguments and pushes them back on.
first cut of adding backend fallback for conjugation
ghstack-source-id: 4fc04d5
Pull Request resolved: #43702
aten/src/ATen/ConjugateFallback.cpp
Outdated
void conjugateFallback(const c10::OperatorHandle& op, torch::jit::Stack* stack) {
  // Unwrap all arguments
  const auto num_arguments = op.schema().arguments().size();
  const auto arguments = torch::jit::last(stack, num_arguments);
torch::jit::last returns a non-owning reference to the arguments. When you pop them later in line 14, this reference is no longer valid; hence your memory corruption.
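For example (a sketch of the fix, not the PR's code), taking an owning copy before popping avoids the dangling reference:

```cpp
const auto num_arguments = op.schema().arguments().size();
// torch::jit::last returns a non-owning ArrayRef into the stack...
auto args_ref = torch::jit::last(*stack, num_arguments);
// ...so copy the IValues into an owning vector before popping them off.
std::vector<c10::IValue> arguments(args_ref.begin(), args_ref.end());
torch::jit::drop(*stack, num_arguments);  // safe: `arguments` owns its IValues
```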
fixed previous issue and generally cleaned up comments/code. I had to define conj_materialize() inside of the ConjugateFallback file. Before that I tried adding it directly to native_functions.yaml (which we probably don't want anyway), which caused the dispatcher to infinite-loop (repeatedly alternating between conj_materialize() and conjugateFallback()).
Responding to the first set of PR feedback. Main change = using TORCH_LIBRARY_IMPL to register my conjugate fallback (plus other minor fixes).
I'm still hitting issues - tracing through my failing test, I see two problems:
- my call to 'a.conj_view() + b' fails with "There were no tensor arguments to this function...", which sounds vaguely to me like I corrupted the call stack in some way in my conjugateFallback function call
- tracing through a call to conjugateFallback(), I see that none of the ivalue arguments on the stack are actually tensors (ivalue.isTensor() == false), so the function just pops off all of the arguments and pushes them back on.
first cut of adding backend fallback for conjugation
ghstack-source-id: 0f01918
Pull Request resolved: #43702
FYI, the CI failures are real.
aten/src/ATen/ConjugateFallback.cpp
Outdated
Tensor conj_materialize(const Tensor& self) {
  // this function assumes that the tensor input has its conjugate bit set
  self.set_conjugate(false);
This is not actually sound because it will affect all other references to the tensor; you need to allocate a new tensor (sharing storage/sizes/strides) and then set the conjugate bit to false and then pass it to conj. Or maybe we come up with an alternate way to call self.conj() that bypasses the conjugate flag and interprets the tensor in the "original" way (similar to how AutoNonVariableTypeMode works).
One reason I'm a little hesitant about a NoConjMode is that it makes the invariants on the system a little more difficult to understand. With your current system, it's easy to see how conjugate bits relate to the Conjugate dispatch key: before running the Conjugate dispatch key, you may have tensors with conjugate bit set; after the dispatch key, nothing has conjugate bit set. With NoConjMode, you may selectively "let tensors with conjugate bit" through to the backend, and so invariants get a little harder to see.
That being said, it's a little hard for me to understand how to structure things in a reasonable way without letting the conjugate bit pass through to the backend. For example, matrix multiply with fused conjugation should definitely be implemented in the CPU/CUDA backend, so we definitely have to pass the conjugate bit all the way to matmul. So perhaps the invariant isn't worth preserving.
It sounds like you're saying that we can lump backend kernels into two groups:
- conjugation-aware: a (relatively small number of) backend/op pairs that can take in a conjugate flag and do conjugation along with the backend operation (e.g. multiplication on CPU/CUDA) in a performant way (something we want to take advantage of). We DO want the conjugate bit to pass through to the backend in this case.
- conjugation-unaware: kernels that require conjugation to be performed separately (and part of the question is whether we want to explicitly do the conjugation in the kernel, or move the logic somewhere else, like I'm doing in this PR). It would be (maybe) preferable to not pass the conjugate bit through to the backend, to make the system easier to reason about.
Would it be reasonable to make the dispatcher aware of the specific op/backend pairs that can take in a conjugation flag, and know to only call out to my conjugate fallback if (a) the conjugate key is set on a tensor input, and (b) the particular op/backend pair currently being called is not in that first category?
I guess that would make this conjugateFallback kernel not really a "fallback" anymore. More like a fallback "for everything except this whitelist of op/backend pairs", which also might become pretty hard to reason about.
It would be (maybe) preferable to not pass the conjugate bit through to the backend, to make the system easier to reason about.
Yes. The only trouble is if you "accidentally" pass along a tensor with the conjugate bit to the backend; if there is no sanity checking in the backend, you will end up with hilarious bugs where the result is the conjugate of the expected result.
Would it be reasonable to make the dispatcher aware of the specific op/backend pairs that can take in a conjugation flag, and know to only call out to my conjugate fallback if (a) the conjugate key is set on a tensor input, and (b) the particular op/backend pair currently being called is not in that first category?
Yes. In fact, we can do this by registering fallthrough functions for operations that support the conjugate bit directly (see the sketch below). It is probably not worth splitting this into a per-backend concept; it is probably better as an all-or-nothing deal.
More like a fallback "for everything except this whitelist of op/backend pairs", which also might become pretty hard to reason about.
I agree. Which is why @zdevito's suggestion to just call materialize_conj manually in all the kernels is appealing. After all, you already had to do some coding to make complex work, this is just one more thing.
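A sketch of the fallthrough idea referenced above (the op name below is just an example of a conjugation-aware kernel; it is not part of this PR):

```cpp
#include <torch/library.h>

// Ops whose backend kernels understand the conjugate bit themselves can opt
// out of the boxed fallback by registering a fallthrough under the Conjugate
// key; dispatch then proceeds directly to the backend kernel for those ops.
TORCH_LIBRARY_IMPL(aten, Conjugate, m) {
  m.impl("mm", torch::CppFunction::makeFallthrough());
}
```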
This prototype looks great. We'll talk about what we should do next with this in our sync shortly. cc @anjali411
aten/src/ATen/ConjugateFallback.cpp
Outdated
  auto conjugated_tensor = conj_materialize(tensor);
  torch::jit::push(stack, conjugated_tensor);
} else {
  torch::jit::push(stack, ivalue);
nit: You can improve the performance of this code by being more careful about moving ivalues on and off the argument stack. This will help you avoid unnecessary reference count bumps. So you would move an argument off the arguments list, and then move it back onto the stack when you push it.
I'm going to try to fix this: instead of popping all arguments off the stack and pushing them all back on (most of which are unchanged except for the conjugated tensor), I'm just going to selectively overwrite the memory in the stack (vector) for the tensors that need to be conjugated. Let me know if that's what you had in mind.
That will be even better :)
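A sketch of that in-place variant (again, an illustration rather than the PR's code; the conjugate-bit getter name is assumed):

```cpp
#include <ATen/ATen.h>
#include <ATen/core/stack.h>
#include <torch/library.h>

at::Tensor conj_materialize(const at::Tensor& self);  // helper sketched earlier

void conjugateFallback(const c10::OperatorHandle& op, torch::jit::Stack* stack) {
  const auto num_arguments = op.schema().arguments().size();
  const auto offset = stack->size() - num_arguments;

  // The stack is just a std::vector<IValue>: overwrite only the entries that
  // hold conjugate-bit tensors and leave everything else untouched.
  for (size_t i = 0; i < num_arguments; ++i) {
    auto& ivalue = (*stack)[offset + i];
    if (ivalue.isTensor() && ivalue.toTensor().is_conj()) {  // getter name assumed
      ivalue = c10::IValue(conj_materialize(ivalue.toTensor()));
    }
  }
  // ...then redispatch as before.
  op.callBoxed(stack);
}
```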
  use_c10_dispatcher: full
  variants: function, method

- func: conj_view(Tensor self) -> Tensor
@bdhirsh asked why we named this separately as conj_view. Based on discussion with @anjali411 and co at #43270, we are seriously considering making this implementation conj (and so the actual kernel that performs the conjugation out of place will need to be given a different name).
c10/core/TensorImpl.h
Outdated
// Whether or not to conjugate the imaginary part of the tensor
// TODO: is there a reason that we need this bool?
// Doesn't the existence of the conjugate dispatch key give us everything we need?
From in person discussion: No, it's probably fine to omit this boolean here since it's in the dispatch key set.
Some notes about what else we'd need to do to land this:
fixed previous issue and generally cleaned up comments/code. I had to define conj_materialize() inside of the ConjugateFallback file. Before that I tried adding it directly to native_functions.yaml (which we probably don't want anyway), which caused the dispatcher to infinite-loop (repeatedly alternating between conj_materialize() and conjugateFallback()).
Responding to the first set of PR feedback. Main change = using TORCH_LIBRARY_IMPL to register my conjugate fallback (plus other minor fixes).
I'm still hitting issues - tracing through my failing test, I see two problems:
- my call to 'a.conj_view() + b' fails with "There were no tensor arguments to this function...", which sounds vaguely to me like I corrupted the call stack in some way in my conjugateFallback function call
- tracing through a call to conjugateFallback(), I see that none of the ivalue arguments on the stack are actually tensors (ivalue.isTensor() == false), so the function just pops off all of the arguments and pushes them back on.
first cut of adding backend fallback for conjugation
ghstack-source-id: 162a424
Pull Request resolved: #43702
cr feedback: (a) trying an exclude guard instead of removing the conjugate flag directly on the tensor, (b) make conjugate fallback more efficient, (c) removed unnecessary bool conjugate_, plus some other cleanup
Some more updates after working on this on Tuesday:
Based off of ezyang (#44799) and bdhirsh (#43702) 's prototype: Here's a summary of the changes in this PR: This PR adds a new dispatch key called Conjugate. This enables us to make conjugate operation a view and leverage the specialized library functions that fast path with the hermitian operation (conj + transpose). 1. Conjugate operation will now return a view with conj bit (1) for complex tensors and returns self for non-complex tensors as before. This also means `torch.view_as_real` will no longer be a view on conjugated complex tensors and is hence disabled. To fill the gap, we have added `torch.view_as_real_physical` which would return the real tensor agnostic of the conjugate bit on the input complex tensor. The information about conjugation on the old tensor can be obtained by calling `.is_conj()` on the new tensor. 2. NEW API: a) `.conj()` -- now returning a view. b) `.conj_physical()` -- does the physical conjugate operation. If the conj bit for input was set, you'd get `self.clone()`, else you'll get a new tensor with conjugated value in its memory. c) `.conj_physical_()`, and `out=` variant d) `.resolve_conj()` -- materializes the conjugation. returns self if the conj bit is unset, else returns a new tensor with conjugated values and conj bit set to 0. e) `view_as_real_physical` -- as described in (1), it's functionally same as `view_as_real`, just that it doesn't error out on conjugated tensors. g) `view_as_real` -- existing function, but now errors out on conjugated tensors. 3. Conjugate Fallback a) Vast majority of PyTorch functions would currently use this fallback when they are called on a conjugated tensor. b) This fallback is well equipped to handle the following cases: - functional operation e.g., `torch.sin(input)` - Mutable inputs and in-place operations e.g., `tensor.add_(2)` - out-of-place operation e.g., `torch.sin(input, out=out)` - Tensorlist input args - NOTE: Meta tensors don't work with conjugate fallback. 4. Autograd a) `resolve_conj()` is an identity function w.r.t. autograd b) Everything else works as expected. 5. Testing: a) All method_tests run with conjugate view tensors. b) OpInfo tests that run with conjugate views - test_variant_consistency_eager/jit - gradcheck, gradgradcheck - test_conj_views (that only run for `torch.cfloat` dtype) NOTE: functions like `empty_like`, `zero_like`, `randn_like`, `clone` don't propagate the conjugate bit. Follow up work: 1. conjugate view RFC 2. Add neg bit to re-enable view operation on conjugated tensors 3. Update linalg functions to call into specialized functions that fast path with the hermitian operation. Differential Revision: [D28227315](https://our.internmc.facebook.com/intern/diff/D28227315) [ghstack-poisoned]
Based off of ezyang (#44799) and bdhirsh (#43702) 's prototype: Here's a summary of the changes in this PR: This PR adds a new dispatch key called Conjugate. This enables us to make conjugate operation a view and leverage the specialized library functions that fast path with the hermitian operation (conj + transpose). 1. Conjugate operation will now return a view with conj bit (1) for complex tensors and returns self for non-complex tensors as before. This also means `torch.view_as_real` will no longer be a view on conjugated complex tensors and is hence disabled. To fill the gap, we have added `torch.view_as_real_physical` which would return the real tensor agnostic of the conjugate bit on the input complex tensor. The information about conjugation on the old tensor can be obtained by calling `.is_conj()` on the new tensor. 2. NEW API: a) `.conj()` -- now returning a view. b) `.conj_physical()` -- does the physical conjugate operation. If the conj bit for input was set, you'd get `self.clone()`, else you'll get a new tensor with conjugated value in its memory. c) `.conj_physical_()`, and `out=` variant d) `.resolve_conj()` -- materializes the conjugation. returns self if the conj bit is unset, else returns a new tensor with conjugated values and conj bit set to 0. e) `view_as_real_physical` -- as described in (1), it's functionally same as `view_as_real`, just that it doesn't error out on conjugated tensors. g) `view_as_real` -- existing function, but now errors out on conjugated tensors. 3. Conjugate Fallback a) Vast majority of PyTorch functions would currently use this fallback when they are called on a conjugated tensor. b) This fallback is well equipped to handle the following cases: - functional operation e.g., `torch.sin(input)` - Mutable inputs and in-place operations e.g., `tensor.add_(2)` - out-of-place operation e.g., `torch.sin(input, out=out)` - Tensorlist input args - NOTE: Meta tensors don't work with conjugate fallback. 4. Autograd a) `resolve_conj()` is an identity function w.r.t. autograd b) Everything else works as expected. 5. Testing: a) All method_tests run with conjugate view tensors. b) OpInfo tests that run with conjugate views - test_variant_consistency_eager/jit - gradcheck, gradgradcheck - test_conj_views (that only run for `torch.cfloat` dtype) NOTE: functions like `empty_like`, `zero_like`, `randn_like`, `clone` don't propagate the conjugate bit. Follow up work: 1. conjugate view RFC 2. Add neg bit to re-enable view operation on conjugated tensors 3. Update linalg functions to call into specialized functions that fast path with the hermitian operation. Differential Revision: [D28227315](https://our.internmc.facebook.com/intern/diff/D28227315) [ghstack-poisoned]
Based off of ezyang (#44799) and bdhirsh (#43702) 's prototype: Here's a summary of the changes in this PR: This PR adds a new dispatch key called Conjugate. This enables us to make conjugate operation a view and leverage the specialized library functions that fast path with the hermitian operation (conj + transpose). 1. Conjugate operation will now return a view with conj bit (1) for complex tensors and returns self for non-complex tensors as before. This also means `torch.view_as_real` will no longer be a view on conjugated complex tensors and is hence disabled. To fill the gap, we have added `torch.view_as_real_physical` which would return the real tensor agnostic of the conjugate bit on the input complex tensor. The information about conjugation on the old tensor can be obtained by calling `.is_conj()` on the new tensor. 2. NEW API: a) `.conj()` -- now returning a view. b) `.conj_physical()` -- does the physical conjugate operation. If the conj bit for input was set, you'd get `self.clone()`, else you'll get a new tensor with conjugated value in its memory. c) `.conj_physical_()`, and `out=` variant d) `.resolve_conj()` -- materializes the conjugation. returns self if the conj bit is unset, else returns a new tensor with conjugated values and conj bit set to 0. e) `view_as_real_physical` -- as described in (1), it's functionally same as `view_as_real`, just that it doesn't error out on conjugated tensors. g) `view_as_real` -- existing function, but now errors out on conjugated tensors. 3. Conjugate Fallback a) Vast majority of PyTorch functions would currently use this fallback when they are called on a conjugated tensor. b) This fallback is well equipped to handle the following cases: - functional operation e.g., `torch.sin(input)` - Mutable inputs and in-place operations e.g., `tensor.add_(2)` - out-of-place operation e.g., `torch.sin(input, out=out)` - Tensorlist input args - NOTE: Meta tensors don't work with conjugate fallback. 4. Autograd a) `resolve_conj()` is an identity function w.r.t. autograd b) Everything else works as expected. 5. Testing: a) All method_tests run with conjugate view tensors. b) OpInfo tests that run with conjugate views - test_variant_consistency_eager/jit - gradcheck, gradgradcheck - test_conj_views (that only run for `torch.cfloat` dtype) NOTE: functions like `empty_like`, `zero_like`, `randn_like`, `clone` don't propagate the conjugate bit. Follow up work: 1. conjugate view RFC 2. Add neg bit to re-enable view operation on conjugated tensors 3. Update linalg functions to call into specialized functions that fast path with the hermitian operation. Differential Revision: [D28227315](https://our.internmc.facebook.com/intern/diff/D28227315) [ghstack-poisoned]
Based off of ezyang (#44799) and bdhirsh (#43702) 's prototype: Here's a summary of the changes in this PR: This PR adds a new dispatch key called Conjugate. This enables us to make conjugate operation a view and leverage the specialized library functions that fast path with the hermitian operation (conj + transpose). 1. Conjugate operation will now return a view with conj bit (1) for complex tensors and returns self for non-complex tensors as before. This also means `torch.view_as_real` will no longer be a view on conjugated complex tensors and is hence disabled. To fill the gap, we have added `torch.view_as_real_physical` which would return the real tensor agnostic of the conjugate bit on the input complex tensor. The information about conjugation on the old tensor can be obtained by calling `.is_conj()` on the new tensor. 2. NEW API: a) `.conj()` -- now returning a view. b) `.conj_physical()` -- does the physical conjugate operation. If the conj bit for input was set, you'd get `self.clone()`, else you'll get a new tensor with conjugated value in its memory. c) `.conj_physical_()`, and `out=` variant d) `.resolve_conj()` -- materializes the conjugation. returns self if the conj bit is unset, else returns a new tensor with conjugated values and conj bit set to 0. e) `view_as_real_physical` -- as described in (1), it's functionally same as `view_as_real`, just that it doesn't error out on conjugated tensors. g) `view_as_real` -- existing function, but now errors out on conjugated tensors. 3. Conjugate Fallback a) Vast majority of PyTorch functions would currently use this fallback when they are called on a conjugated tensor. b) This fallback is well equipped to handle the following cases: - functional operation e.g., `torch.sin(input)` - Mutable inputs and in-place operations e.g., `tensor.add_(2)` - out-of-place operation e.g., `torch.sin(input, out=out)` - Tensorlist input args - NOTE: Meta tensors don't work with conjugate fallback. 4. Autograd a) `resolve_conj()` is an identity function w.r.t. autograd b) Everything else works as expected. 5. Testing: a) All method_tests run with conjugate view tensors. b) OpInfo tests that run with conjugate views - test_variant_consistency_eager/jit - gradcheck, gradgradcheck - test_conj_views (that only run for `torch.cfloat` dtype) NOTE: functions like `empty_like`, `zero_like`, `randn_like`, `clone` don't propagate the conjugate bit. Follow up work: 1. conjugate view RFC 2. Add neg bit to re-enable view operation on conjugated tensors 3. Update linalg functions to call into specialized functions that fast path with the hermitian operation. Differential Revision: [D28227315](https://our.internmc.facebook.com/intern/diff/D28227315) [ghstack-poisoned]
Summary:
Pull Request resolved: #54987

Based off of ezyang's (#44799) and bdhirsh's (#43702) prototypes. Here's a summary of the changes in this PR:

This PR adds a new dispatch key called Conjugate. This enables us to make the conjugate operation a view and to leverage the specialized library functions that fast-path with the Hermitian operation (conj + transpose).

1. The conjugate operation now returns a view with the conj bit set (1) for complex tensors, and returns self for non-complex tensors as before. This also means `torch.view_as_real` is no longer a view on conjugated complex tensors and is hence disabled. To fill the gap, we have added `torch.view_as_real_physical`, which returns the real view of the tensor regardless of the conjugate bit on the input complex tensor. The information about conjugation on the old tensor can be obtained by calling `.is_conj()` on the new tensor.
2. NEW API:
   a) `.conj()` -- now returns a view.
   b) `.conj_physical()` -- does the physical conjugate operation. If the conj bit of the input was set, you get `self.clone()`; otherwise you get a new tensor with conjugated values in its memory.
   c) `.conj_physical_()`, and an `out=` variant.
   d) `.resolve_conj()` -- materializes the conjugation. Returns self if the conj bit is unset, else returns a new tensor with conjugated values and the conj bit set to 0.
   e) `.resolve_conj_()` -- in-place version of (d).
   f) `view_as_real_physical` -- as described in (1), functionally the same as `view_as_real`, except that it doesn't error out on conjugated tensors.
   g) `view_as_real` -- existing function, but now errors out on conjugated tensors.
3. Conjugate Fallback
   a) The vast majority of PyTorch functions currently use this fallback when they are called on a conjugated tensor.
   b) The fallback handles the following cases (see the sketch after this message):
      - functional operations, e.g. `torch.sin(input)`
      - mutable inputs and in-place operations, e.g. `tensor.add_(2)`
      - operations with an `out=` argument, e.g. `torch.sin(input, out=out)`
      - TensorList input args
      - NOTE: meta tensors don't work with the conjugate fallback.
4. Autograd
   a) `resolve_conj()` is an identity function w.r.t. autograd.
   b) Everything else works as expected.
5. Testing:
   a) All method_tests run with conjugate view tensors.
   b) OpInfo tests that run with conjugate views:
      - test_variant_consistency_eager/jit
      - gradcheck, gradgradcheck
      - test_conj_views (which only runs for the `torch.cfloat` dtype)

NOTE: functions like `empty_like`, `zeros_like`, `randn_like`, and `clone` don't propagate the conjugate bit.

Follow-up work:
1. Conjugate view RFC
2. Add a neg bit to re-enable the view operation on conjugated tensors
3. Update linalg functions to call into specialized functions that fast-path with the Hermitian operation.

Test Plan: Imported from OSS

Reviewed By: VitalyFedyunin

Differential Revision: D28227315

Pulled By: anjali411

fbshipit-source-id: acab9402b9d6a970c6d512809b627a290c8def5f
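The sketch below illustrates the fallback behavior described in (3) and the NOTE about bit propagation. It is a hypothetical usage example written against the semantics stated in this message, not verified output from this PR.

```python
import torch

a = torch.tensor([1 + 1j, 2 - 2j])
v = a.conj()                                # conj view: bit set, nothing materialized yet

# Functional op: the fallback materializes the conjugated input before dispatching,
# so this should match torch.sin(a.resolve_conj()).
s = torch.sin(v)

# out= variant: also routed through the fallback.
out = torch.empty_like(a)
torch.sin(v, out=out)

# In-place op on a conj view: the fallback handles mutable inputs, so the update
# is reflected through the view back into a's storage.
v.add_(2)

# The conj bit is not propagated by *_like factories or clone().
print(torch.empty_like(v).is_conj())        # False
print(v.clone().is_conj())                  # False
```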
Stack from ghstack: