-
Notifications
You must be signed in to change notification settings - Fork 88
RFC-0011-InferenceMode #17
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Conversation
RFC: pytorch/rfcs#17 Differential Revision: [D26973911](https://our.internmc.facebook.com/intern/diff/D26973911) [ghstack-poisoned]
RFC: pytorch/rfcs#17 ghstack-source-id: 20520fc Pull Request resolved: #53343
RFC-0011-InferenceMode.md
Outdated
Note: a large part of this RFC will become "InferenceMode" documentation once it's finalized. | ||
|
||
## Goals: | ||
- Provide a RAII in C++ and a context manager in Python frontend to switch between inference mode and normal mode, with the following constraints: |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
The Python frontend context manager isn't that important, right? Because this is oriented to performance use cases where you ought to be in C++ only anyway (it's good that it is possible and maybe some time we should add it, but I wouldn't say it's a primary goal)
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Yep agreed, didn't plan to add that until we become stable on the C++ end. Mentioning it here just to make sure it's possible. :D
RFC-0011-InferenceMode.md
Outdated
## Goals: | ||
- Provide a RAII in C++ and a context manager in Python frontend to switch between inference mode and normal mode, with the following constraints: | ||
- correctness is always guaranteed. (compared to `AutoNonVariableType` which has risks producing silent wrong result.) | ||
- performance of infenrence mode should match current existing `AutoNonVariableTypeMode` which is widely used in prod. |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
nit: inference
Signed-off-by: Edward Z. Yang <ezyang@fb.com>
RFC-0011-InferenceMode.md
Outdated
- Make `AutoNonVariableTypeMode` an internal only API, replace all callsites of `AutoNonVariableTypeMode` outside pytorch codebase with the new `InferenceMode`. | ||
|
||
## Non-goals: | ||
- Match the theoretical best inference performance which can be achieved by stripping all autograd related stuff at build time (not flexible). |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Actually, are we sure about this? If you write code solely in InferenceMode, with no interaction with non-InferenceMode tensors, it seems to me that theoretical best performance should be attainable (since we never have to hit the safety code for the intermixing situation).
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
oh yes, without the intermixing situation (and maybe some startup cost initializing dispatch table) the performance should be attainable. :D
RFC-0011-InferenceMode.md
Outdated
The following modes are ranked from slowest to fastest in speed, and from the most flexible to the most restrictive in what users can do. | ||
|
||
* Normal Mode: we create the graph for all Tensors that require gradients, always track view and inplace even they don't require gradients. | ||
* GradMode disabled: we never create the graph, still track all views and inplace. User code always succeeds to properly track gradients. |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
@albanD I kind of want a little more meat here, along the lines of the design principle @apaszke explicated at pytorch/pytorch#12502 (review)
One possible angle here is that no_grad
is strictly local: the effects of this context manager affect what happens inside the block, but everything outside of the block is still fair game for full autograd support. (This is not as "hard" a design philosophy, but it's my best understanding right now.) Inference mode, on the other hand, is OK with letting the implications of this context escape; its goal is performance, and so if there is no way to implement something fast without affecting stuff outside of the inference mode block, we will just let it leak out.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
That good old discussion haha
As mentioned there, this is not really true for gradmode as views created there will lead to some side effects even outside the block.
But I do agree with you that when using this mode, all ops performed outside of it will always work. Which is not true for inference mode.
I am not sure what kind of details you're looking for here?
cc @ssnl, three years later we might actually merge a version of your pytorch/pytorch#12502 haha! |
RFC-0011-InferenceMode.md
Outdated
return input_base.expand(size_vec, implicit); | ||
}; | ||
} | ||
auto result = as_view(/* base */ self, /* output */ _tmp, /* is_bw_differentiable */ true, /* is_fw_differentiable */ true, /* view_func */ func, /* creatio |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
copy paste problem
RFC-0011-InferenceMode.md
Outdated
- **Normal tensor** has both Autograd & InplaceOrView keys. This includes both `requires_grad=true` and `requires_grad=false` tensors. (see [Ideal end state] section for more details). | ||
- Additional notes: | ||
- All Inference tensors are created in inference mode, but not all of the tensors created in inference mode are inference tensors. For example, a view of normal tensor created in inference mode is still a normal tensor (but with special `creation_meta`!). | ||
- (Autograd & !InplaceOrView) and (!Autogad & InplaceOrView) are invalid states, we don't have such tensors. |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Maybe we should have called this Autograd and NoAutograd LOL
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
you mean autograd tensors and NoAutograd tensors? I like those names but they sounds too related to the GradMode which will be confusing to users :(
… error handling" RFC: pytorch/rfcs#17 Differential Revision: [D26973911](https://our.internmc.facebook.com/intern/diff/D26973911) [ghstack-poisoned]
RFC: pytorch/rfcs#17 Differential Revision: [D26973911](https://our.internmc.facebook.com/intern/diff/D26973911) [ghstack-poisoned]
… error handling" RFC: pytorch/rfcs#17 Differential Revision: [D26973911](https://our.internmc.facebook.com/intern/diff/D26973911) [ghstack-poisoned]
RFC: pytorch/rfcs#17 Differential Revision: [D26973911](https://our.internmc.facebook.com/intern/diff/D26973911) [ghstack-poisoned]
… error handling" RFC: pytorch/rfcs#17 Differential Revision: [D26973911](https://our.internmc.facebook.com/intern/diff/D26973911) [ghstack-poisoned]
RFC: pytorch/rfcs#17 Differential Revision: [D26973911](https://our.internmc.facebook.com/intern/diff/D26973911) [ghstack-poisoned]
… error handling" RFC: pytorch/rfcs#17 Differential Revision: [D26973911](https://our.internmc.facebook.com/intern/diff/D26973911) [ghstack-poisoned]
RFC: pytorch/rfcs#17 Differential Revision: [D26973911](https://our.internmc.facebook.com/intern/diff/D26973911) [ghstack-poisoned]
… error handling" RFC: pytorch/rfcs#17 Differential Revision: [D26973911](https://our.internmc.facebook.com/intern/diff/D26973911) [ghstack-poisoned]
# Why is the change needed? When audio chunks are 40 ms long, the inplace operation that replaces `NaN` with zeros currently fails with the following `RuntimeError`: ``` RuntimeError: Inplace update to inference tensor outside InferenceMode is not allowed.You can make a clone to get a normal tensor before doing inplace update.See pytorch/rfcs#17 for more details. ``` For more information regarding the torch inference mode, refer to [the Torch guide](https://pytorch.org/docs/stable/generated/torch.inference_mode.html). # What changes does the patch introduce? Uses inference mode when performing all operations on tensors. # How was this patch tested? Manual execution on a problematic SRT, UTs already exist but do not cover tensors coming from the Sonar encoder.
Error from RuntimeError: Inplace update to inference tensor outside InferenceMode is not allowed.You can make a clone to get a normal tensor before doing inplace update.See pytorch/rfcs#17 for more details
…de InferenceMode is not allowed.You can make a clone to get a normal tensor before doing inplace update.See pytorch/rfcs#17 for more details.) occurs when attempting to merge LoRA while inference mode is enabled.
Rendered