Dispatch to Python via __torch_dispatch__ by ezyang · Pull Request #59760 · pytorch/pytorch · GitHub

Conversation

@ezyang
Contributor

@ezyang ezyang commented Jun 9, 2021

Stack from ghstack:

See #59049

There are some moving parts to this PR; I'll structure this explanation so the straightforward parts come first, followed by the less straightforward parts.

**The actual dispatch to Python.** The core logic of dispatch to Python lives in `concrete_dispatch_fn` in `torch/csrc/autograd/python_variable.cpp`. It takes the input IValue stack, scans all the arguments for Tensor arguments, and defers most of the heavy lifting to `handle_torch_function_no_python_arg_parser`, which actually does all of the logic for calling out to torch dispatch (in particular, this function handles multiple dispatch situations for you). Because we have a different function name than regular `__torch_function__` handling, `handle_torch_function_no_python_arg_parser` is generalized to accept a magic method name to look for when testing if Tensors have custom handling or not. Unlike `__torch_function__`, by default there is no `__torch_dispatch__` on Tensor classes.
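To make the shape of the hook concrete, here is a minimal sketch (illustrative only, not code from this PR); the `EchoTensor` name and the unwrap-and-reinvoke pattern are just one way to avoid immediately recursing back into the hook:

```python
import torch

class EchoTensor(torch.Tensor):
    # Defining __torch_dispatch__ opts this subclass into dispatch to Python;
    # plain torch.Tensor deliberately has no such attribute by default.
    @classmethod
    def __torch_dispatch__(cls, func, types, args=(), kwargs=None):
        kwargs = kwargs or {}
        print(f"intercepted: {func}")  # func is the operator being dispatched
        # Unwrap subclass arguments to plain Tensors so the re-invocation below
        # does not land back in this hook.
        def unwrap(a):
            return a.as_subclass(torch.Tensor) if isinstance(a, EchoTensor) else a
        return func(*map(unwrap, args),
                    **{k: unwrap(v) for k, v in kwargs.items()})
```

Evaluating, say, `torch.randn(3).as_subclass(EchoTensor) + 1` should then print the intercepted op before computing the result on plain tensors.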

**Maintaining the Python dispatch key.** In order to get to the dispatch to Python logic, we must tag Tensors that have the `__torch_dispatch__` magic method with the newly added Python dispatch key (separated from PythonFuncTorch to allow for a transitional period while they migrate to this mechanism). We expose a new private property `_is_python_dispatch` that assists in debugging whether a Tensor is participating in Python dispatch or not. We apply the Python dispatch key the first time a PyObject for a Tensor is constructed (`THPVariable_NewWithVar`), testing whether `__torch_dispatch__` exists with the newly added `check_has_torch_dispatch`.

**Shallow copy and detach.** For the simple examples tested in this PR, most creations of Tensors route through the dispatcher. The exception to this is `shallow_copy_and_detach`, which bypasses the dispatcher and is used when saving tensors for backwards. When a Tensor participates in Python dispatch, we override the behavior of `shallow_copy_and_detach` to instead call directly into `__torch_dispatch__` to perform a `detach` operation (in the same way it would be invoked if you called `detach` directly). Because this Python call is triggered directly from `c10::TensorImpl`, it must be indirected through `PyInterpreter::detach`, which is the general mechanism for dynamically dispatching to the Python interpreter associated with a TensorImpl.
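Concretely (a hedged sketch reusing the illustrative `EchoTensor` from above): when autograd saves an input for backward, the save should now surface in the hook as a `detach` call instead of silently bypassing Python.

```python
x = torch.randn(3).as_subclass(EchoTensor).requires_grad_()
y = x * x
# Expected, roughly: an "intercepted: ... mul ..." line for the forward op, plus
# "intercepted: ... detach ..." line(s) emitted while autograd saves x for
# backward, routed through PyInterpreter::detach rather than the dispatcher.
```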

**torchdeploy compatibility.** The dispatch to Python logic cannot be directly registered to the dispatcher, as it is compiled in the Python library, which will get loaded multiple times per torchdeploy interpreter. Thus, we must employ a two-phase process. First, we register a fallback inside a non-Python library (`aten/src/ATen/core/PythonFallbackKernel.cpp`). Its job is to determine the appropriate PyInterpreter to handle the Python dispatch by going through all of the arguments and finding the first argument that has a PyObject/PyInterpreter. With this PyInterpreter, it makes another dynamic dispatch via `dispatch`, which will go to the correct torchdeploy interpreter to handle dispatching to actual Python.

**Interaction with `__torch_function__`.** I slightly modified the default `Tensor.__torch_function__` to test whether the returned tensor is already of the desired subclass before rewrapping. This makes it interoperate smoothly with `__torch_dispatch__`, without requiring users to explicitly disable the default `__torch_function__` handling.

**Testing.** We provide a simple example of a `LoggingTensor` for testing, which can be used to generate TorchScript-like traces to observe what operations are being called when operations are invoked on a Tensor. Although a LoggingTensor would be better implemented via an is-a relationship rather than the has-a relationship used in the test, we've done it this way to show that arbitrarily complex compositions of tensors inside a tensor work properly.
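The has-a structure is roughly the following (a hedged sketch with illustrative names, not the test's actual LoggingTensor; it only handles single-output ops for brevity):

```python
import torch

LOG = []  # flat list of op names observed underneath autograd

class LoggingLite(torch.Tensor):
    @staticmethod
    def __new__(cls, elem):
        # Build a wrapper whose real data lives in `elem` (has-a); _make_subclass
        # is one way to construct the wrapper instance. The real test is more
        # careful here, e.g. about autograd metadata.
        r = torch.Tensor._make_subclass(cls, elem)
        r.elem = elem
        return r

    # Disable the default __torch_function__ subclass-preservation behavior so
    # wrapping is controlled entirely from __torch_dispatch__ (see the known
    # limitations below).
    __torch_function__ = torch._C._disabled_torch_function_impl

    @classmethod
    def __torch_dispatch__(cls, func, types, args=(), kwargs=None):
        def unwrap(a):
            return a.elem if isinstance(a, LoggingLite) else a
        def wrap(a):
            return LoggingLite(a) if isinstance(a, torch.Tensor) else a
        LOG.append(str(func))  # record the op, then rerun it on the inner tensors
        out = func(*[unwrap(a) for a in args],
                   **{k: unwrap(v) for k, v in (kwargs or {}).items()})
        return wrap(out)
```

Running `(LoggingLite(torch.ones(2)) + 2).sum()` should then leave an add and a sum entry in `LOG`, which is the TorchScript-like trace referred to above.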

**Known limitations.**

* We haven't adjusted any operator code, so some patterns may not work (as they lose the Python subclass in an unrecoverable way)
* It's not possible yet to redispatch excluding `__torch_dispatch__`
* `__torch_function__` must be explicitly disabled with `_disabled_torch_function_impl`, otherwise things don't work quite correctly (in particular, what is being disabled is the default subclass preservation behavior)
* We don't ever populate kwargs, even when an argument is kwarg-only
* Remove arguments when they match the defaults (and provide a different mechanism for getting full defaults)

Signed-off-by: Edward Z. Yang <ezyang@fb.com>

Differential Revision: D29017912

@facebook-github-bot
Contributor

facebook-github-bot commented Jun 9, 2021

💊 CI failures summary and remediations

As of commit ab4b914 (more details on the Dr. CI page and at hud.pytorch.org/pr/59760):


  • 2/2 failures possibly* introduced in this PR
    • 1/2 non-scanned failure(s)

🕵️ 1 new failure recognized by patterns

The following CI failures do not appear to be due to upstream breakages:

See GitHub Actions build Lint / mypy (1/1)

Step: "Run mypy" (full log | diagnosis details | 🔁 rerun)

2021-06-23T15:47:14.8792648Z env:
2021-06-23T15:47:14.8793320Z   pythonLocation: /opt/hostedtoolcache/Python/3.8.10/x64
2021-06-23T15:47:14.8794177Z   LD_LIBRARY_PATH: /opt/hostedtoolcache/Python/3.8.10/x64/lib
2021-06-23T15:47:14.8794816Z ##[endgroup]
2021-06-23T15:47:14.8946361Z + for CONFIG in mypy*.ini
2021-06-23T15:47:14.8947569Z + mypy --config=mypy-strict.ini
2021-06-23T15:47:27.8351349Z torch/utils/benchmark/utils/common.py:166:34: error: Call to untyped function "median" in typed context  [no-untyped-call]
2021-06-23T15:47:27.8354272Z torch/utils/benchmark/utils/common.py:168:31: error: Call to untyped function "percentile" in typed context  [no-untyped-call]
2021-06-23T15:47:27.8357265Z torch/utils/benchmark/utils/common.py:169:31: error: Call to untyped function "percentile" in typed context  [no-untyped-call]
2021-06-23T15:47:27.8360331Z torch/utils/benchmark/utils/common.py:281:18: error: Call to untyped function "round" in typed context  [no-untyped-call]
2021-06-23T15:47:29.8807327Z torch/utils/benchmark/utils/timer.py:304:24: error: Call to untyped function "median" in typed context  [no-untyped-call]
2021-06-23T15:47:34.0791788Z Found 5 errors in 2 files (checked 141 source files)
2021-06-23T15:47:34.9540134Z ##[error]Process completed with exit code 1.

This comment was automatically generated by Dr. CI.

@ezyang ezyang requested review from Chillee, bdhirsh and zou3519 June 9, 2021 23:39
ezyang added a commit that referenced this pull request Jun 10, 2021
See #59049

Signed-off-by: Edward Z. Yang <ezyang@fb.com>

ghstack-source-id: eba5ea8
Pull Request resolved: #59760
@ezyang
Contributor Author

ezyang commented Jun 10, 2021

@ezyang has imported this pull request. If you are a Facebook employee, you can view this diff on Phabricator.

// Perform a detach by deferring to the __torch_dispatch__ implementation of
// detach, which will also arrange for the PyObject to get copied in this
// situation
__ubsan_ignore_function__ c10::intrusive_ptr<TensorImpl> detach(const TensorImpl* self) const {
Contributor

@zou3519 zou3519 Jun 10, 2021

Shallow copy and detach. For the simple examples tested in this PR, most creations of Tensor route through the dispatcher. The exception to this is shallow_copy_and_detach, which bypasses the dispatcher and is used when saving tensors for backwards. When a Tensor is Python dispatch, we override the behavior of shallow_copy_and_detach to instead directly call into torch_dispatch to perform a detach operation (in the same way it would be invoked if you called detach directly).

In some cases, shallow_copy_and_detach is a little more than a detach; it also resets the version counter(?) (see link). Is this a problem?

Contributor Author

Oh alright, I guess we need a different operator for this.

Collaborator

I think we can make shallow_copy_and_detach = detach + set VC + set metadata change flag?

Contributor Author

Yeah, so we could also just manually fix things up after the user implementation returns. TBH, I'm not sure what the right thing to do here is.

Collaborator

@albanD albanD left a comment

That is surprisingly simple! Good job!

bool allow_tensor_metadata_change) const {
if (key_set_.has(DispatchKey::Python)) {
auto r = pyobj_interpreter_.load(std::memory_order_acquire)->detach(this);
if (r) return r;
Collaborator

As mentioned by Richard, we might want to also update the version counter and allow_tensor_metadata_change flag here after the detach.

def __repr__(self):
return f"LoggingTensor({self.elem})"

def __str__(self):
Collaborator

__str__ calls into __repr__ by default, no? There should be no need to duplicate this code.

__torch_function__ = _disabled_torch_function_impl

@classmethod
def __torch_dispatch__(cls, func, types, args=(), kwargs=None):
Collaborator

Should we remove the kwargs here to make sure it is not ambiguous to users, since we always pass them as positional arguments?

Contributor Author

Yeah, if we decide to never pass kwargs (I'm still not 100% clear what we should do here), we should modify the signature here. It's a little annoying because I have to add more conditional swiss cheese.

@bdhirsh
Contributor

bdhirsh commented Jun 14, 2021

The exception to this is shallow_copy_and_detach, which bypasses the dispatcher and is used when saving tensors for backwards

Just to be explicit on why that is: what's the reasoning for not making shallow_copy_and_detach a dispatcher-aware op, like everything else? I can imagine two reasons:

(1) perf, since it's called all the time in autograd and we don't want the dispatcher overhead
(2) Up until dispatch-to-python, it's never had to have multiple implementations to dynamically dispatch between. It's backend agnostic, and doesn't have any specific modal-key behavior aside from the new python behavior.

Those seem pretty valid. But the upside to making it dispatcher-aware would be that we wouldn't need to add all of the heavy custom support for shallow_copy_and_detach inside of PyInterpreter, right? (Wouldn't the custom version of the detach() function be fully accounted for by the boxed dispatch() kernel?)

@bdhirsh
Contributor

bdhirsh commented Jun 14, 2021

Also, not directly related to this PR, but I was digging through the python_variable.cpp code to try to understand it better, and I see that THPVariable_Wrap() assumes a PyInterpreter status of MAYBE_UNINITIALIZED when there's no interpreter on the tensor object, which means we need to do a CES to set it on initialization:

status = c10::impl::PyInterpreterStatus::MAYBE_UNINITIALIZED;

Are we forced to pessimistically make that assumption, because we don't know if we're running in a multiple-interpreter environment? It looks like we're forced to do the CES in pretty standard pytorch code:

>>> import torch
a = torch.ones(2)
NewWithVar. status=MAYBE_UNINITIALIZED
>>> b = torch.ones(2)
NewWithVar. status=MAYBE_UNINITIALIZED
>>> c = a + b
NewWithVar. status=MAYBE_UNINITIALIZED
>>>

That seems like a potential perf-improvement opportunity, since it looks like THPVariable_Wrap() is used all the time. Are we able to detect when, e.g., torch deploy isn't running, or we only have one interpreter, and throw in a DEFINITELY_UNINITIALIZED?

@ezyang
Contributor Author

ezyang commented Jun 15, 2021

That seems like a potential perf-improvement opportunity, since it looks like THPVariable_Wrap() is used all the time. Are we able to detect when, e.g., torch deploy isn't running, or we only have one interpreter, and throw in a DEFINITELY_UNINITIALIZED?

I agree there is a perf improvement opportunity here. Actually, it's somewhat related to @swolchok's recent #59419 : if we get an ExclusivelyOwned tensor, we can know that it is DEFINITELY_UNINITIALIZED. But an easier fix is to make a special THPVariable_Wrap variant for factory function Python bindings for when we know we just created the tensor.

@ezyang
Contributor Author

ezyang commented Jun 21, 2021

This is ready for final review!

ezyang added a commit that referenced this pull request Jun 21, 2021
See #59049

Signed-off-by: Edward Z. Yang <ezyang@fb.com>

ghstack-source-id: 16008cb
Pull Request resolved: #59760
@ezyang
Contributor Author

ezyang commented Jun 21, 2021

@ezyang has imported this pull request. If you are a Facebook employee, you can view this diff on Phabricator.

ezyang added a commit that referenced this pull request Jun 23, 2021
See #59049

Signed-off-by: Edward Z. Yang <ezyang@fb.com>

ghstack-source-id: f11d532
Pull Request resolved: #59760
@ezyang
Contributor Author

ezyang commented Jun 23, 2021

@ezyang has imported this pull request. If you are a Facebook employee, you can view this diff on Phabricator.

@ezyang ezyang requested a review from bhosmer June 24, 2021 17:58

# TODO: move this into library proper
@contextlib.contextmanager
def no_dispatch() -> Iterator[None]:
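For reference, a plausible completion of this helper (a sketch only; it assumes the private `torch._C._DisableTorchDispatch` guard, which suppresses the Python dispatch key while it is alive):

```python
import contextlib
from typing import Iterator

import torch

@contextlib.contextmanager
def no_dispatch() -> Iterator[None]:
    # Keep the guard alive for the body of the with-block, then drop it so the
    # Python dispatch key becomes visible to the dispatcher again.
    guard = torch._C._DisableTorchDispatch()
    try:
        yield
    finally:
        del guard
```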
Contributor

Just want to confirm two questions around __torch_dispatch__ best practices, assuming the two common use cases of:

  • "is a" (where the subclass uses itself as the actual tensor, and doesn't contain any tensor members)
  • "has a" (where the subclass just treats itself as a meta tensor, and has at least one tensor member that it calls ops from __torch_dispatch__ on)

(a) a context manager like no_dispatch() isn't needed in the vanilla "has a" case, and is only important in the "is a" case. Assuming that the member tensors are just ordinary tensors and not themselves subclasses, we don't have to worry about recursively hitting the python key again.

(b) Depending on if/how we decide to expose a redispatching API, it seems like that could probably replace the need for this context manager (at least in the vanilla "is a" case)? Since DispatchKey::Python has lower precedence than all other non-backend keys, we'd expect to go through all the other behavioral keys before hitting torch_dispatch, so the common case would probably be to just directly dispatch to the backend kernel.

I.e. instead of

with no_dispatch():
    output = func(*args)

We have something like

output = func.redispatch(*args) # dispatches below Python key, to backend kernel

Contributor Author

Yes, agreed on both counts.

Comment on lines +128 to +129
// TODO: test if Python key is disabled
PyObject_FastGetAttrString(obj, "__torch_dispatch__").ptr() != nullptr
Contributor

Is this TODO happening now or later?

Contributor Author

actually I think this TODO is obsolete

Contributor

@zou3519 zou3519 left a comment

Got through everything -- the PR body and Ed's torchdeploy podcast were helpful in figuring out exactly what was going on. There should probably be a TODO/followup item for what happens with operations on tensor.data (we discussed offline that tensor.data uses shallow_copy_and_detach to create a Tensor with a new version counter)

py::handle torch_api_function = py::module::import("torch").attr("ops").attr(ns).attr(func_name);
std::string module_name_str = "torch.ops." + ns_str;

for (int64_t idx = 0; idx < arguments.size(); idx++) {
Contributor

nit: for (const auto& ivalue : arguments)

}
} else if (ivalue.isList()) {
const auto& list = ivalue.toListRef();
for (int64_t jdx = 0; jdx < list.size(); jdx++) {
Contributor

nit: for (const auto& nv : list)


py::gil_scoped_acquire g;

std::vector<py::handle> overloaded_args;
Contributor

nit: could be SmallVector (most ops accept 4 or fewer tensors)

torch::jit::push(stack, torch::jit::toIValue(out.ptr(), op.schema().returns()[0].type()));
} else {
auto outs = py::cast<py::sequence>(out);
for (unsigned idx = 0; idx < outs.size(); idx++) {
Contributor

nit: `for (const auto idx : c10::irange(0, outs.size()))` makes it so that you don't have to worry about picking the correct type (`unsigned`)

# 3. Enter dispatcher, wind your way through Autograd
# 4. Hit Python dispatch key, call __torch_dispatch__

# TODO: TensorBase should work
Contributor

Is the TODO still relevant? I see a subclass of TensorBase down below: class A(torch._C._TensorBase):

Contributor Author

TensorBase still doesn't work for some autograd API functions where they test isinstance(a, Tensor) rather than isinstance(a, TensorBase)

@facebook-github-bot
Contributor

@ezyang merged this pull request in aacc722.

asuhan pushed a commit to asuhan/pytorch that referenced this pull request Jun 28, 2021
Summary:
Pull Request resolved: pytorch#59760

See pytorch#59049

Differential Revision: D29017912

Test Plan: Imported from OSS

Reviewed By: bdhirsh

Pulled By: ezyang

fbshipit-source-id: a67714d9e541d09203a8cfc85345b8967db86238
@facebook-github-bot facebook-github-bot deleted the gh/ezyang/1038/head branch June 29, 2021 14:22
asuhan pushed a commit that referenced this pull request Jun 30, 2021
Summary:
Pull Request resolved: #59760

See #59049

Differential Revision: D29017912

Test Plan: Imported from OSS

Reviewed By: bdhirsh

Pulled By: ezyang

fbshipit-source-id: a67714d9e541d09203a8cfc85345b8967db86238
zou3519 added a commit that referenced this pull request Jul 9, 2021
functorch is unable to `vmap(grad(f))` when `f` contains a `.contiguous`
call. This is because `.contiguous` (when it is not a no-op) decomposes
to `.copy_` under grad and the `.copy_` is not compatible with vmap.

The fix for this is to have `.contiguous` call `.clone` instead of
`.copy_`. `clone` is a primitive w.r.t. autograd, so `grad`
decomposes contiguous into clone.

Perf testing (forward pass)
- [script and
output](https://gist.github.com/zou3519/294f583b9c5d7bdf234d5295f97fb02e)
- The instruction count increased from 774479 to 781379. This is because
we're now calling .clone(), which does an additional dispatch. We could
optimize the implementation of clone() to not dispatch on .copy_() in
the future if we really care about this.

Perf testing (backward pass)
- [script and
output](https://gist.github.com/zou3519/6fbdb121de6342334192d55c8a72276a)
- The instruction count decreased from 5402648 to 5335977. This is
because the [backward for
.clone](https://github.com/pytorch/pytorch/blob/9b908ab0d0a947d89ac3137f8c4a05a87c35f568/tools/autograd/derivatives.yaml#L383)
is a lot simpler than the [backward for
copy_](https://github.com/pytorch/pytorch/blob/9b908ab0d0a947d89ac3137f8c4a05a87c35f568/torch/csrc/autograd/functions/tensor.cpp#L37-L41)
- The backward for .clone() and .copy_() end up doing the same thing for
contiguous (from reading the code above, they both do no-op copies).

Test Plan:
- wait for existing tests (test_view_ops have the tests)
- functorch isn't tested in PyTorch CI yet.
- Taking suggestions on how to write a test for this. I'm thinking we
could use LoggingTensor from #59760 (because it logs underneath
autograd) and test that clone is called instead of copy_ but I didn't
think too hard about it

[ghstack-poisoned]
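A rough sketch of the kind of check suggested in the test plan above (illustrative only: the recording subclass stands in for the LoggingTensor utilities from #59760, and the expected trace contents are an assumption):

```python
import torch

SEEN = []

class OpRecorder(torch.Tensor):
    # Minimal recording subclass for illustration; the real LoggingTensor is more
    # careful about wrapping and autograd metadata.
    @classmethod
    def __torch_dispatch__(cls, func, types, args=(), kwargs=None):
        SEEN.append(str(func))
        def unwrap(a):
            return a.as_subclass(torch.Tensor) if isinstance(a, OpRecorder) else a
        return func(*map(unwrap, args),
                    **{k: unwrap(v) for k, v in (kwargs or {}).items()})

x = torch.randn(3, 4).t().as_subclass(OpRecorder)  # non-contiguous input
x.contiguous()

# After the contiguous -> clone change one would expect a clone entry in SEEN
# rather than an empty_like/copy_ pair; asserting on that would pin the behavior.
print(SEEN)
```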
facebook-github-bot pushed a commit that referenced this pull request Jul 12, 2021
Summary:
Pull Request resolved: #61456


Reviewed By: soulitzer

Differential Revision: D29636859

Pulled By: zou3519

fbshipit-source-id: 97eb56bfae1c4bb31612dc9d06536019f21d69a6