-
Notifications
You must be signed in to change notification settings - Fork 25.7k
partial attempt at stopping non-differentiable values from being materialized #110592
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Conversation
…rialized [ghstack-poisoned]
# If it is a Tensor, what the dynamic dims are (otherwise is None) | ||
dynamic_dims: Optional[Set[int]] | ||
# requires_grad | ||
requires_grad: bool |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
nit: ViewAndMutationMeta.requires_grad_info
already has this info, since it is of length (# mutated inputs + # user outputs).
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Any reason you prefer to have a separate object rather than have it on here? My intuition was that the less lists we have with implicitly matching order the better.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Yeah this is a reasonable q: Let me give context. There are a few other places where we need to distinguish which tensors require gradients:
(1) Affects both mutated inputs and fw outputs (code): At runtime, the compiled forward graph returns both any updated inputs, and any user fw outs. For both of these groups of tensors, we need to know which ones do not require gradients, so we can mark them as non_differentiable.
(2) Affects just mutated inputs (code): During tracing of the joint, if we have a mutated input that requires gradients, we need to clone() it in the forward, so that we can autograd.grad() w.r.t. the input pre-mutation (this is actually kind of sub-optimal since the clone() affects runtime perf and is just to appease the autograd engine, but this case should be rare).
(3) Affects just fw outputs (code): At runtime after our CompiledFunction.apply()
finishes, we need to regenerate fw outs that alias inputs. We need to know if the output alias no longer requires grad, which implies that a detach() happened during tracing (so we need to reapply that detach at runtime).
If we do what you put above and move the requires_grad info directly on OutputInfo
, then we'd have to either:
(a) Also put requires_grad info on InputInfo
objects - but, every piece of code in AOTAutograd that care about requires grad information on inputs would have to manually filter down to the input infos that correspond to mutated inputs
(b) Add another piece of metadata that just correponds to "mutated inputs", that we tack requires_grad info on.
Let me know what you think. (a) Actually might make things easier to reason about (even if it's a bit more boilerplate), but it will be a pretty annoying refactor.
for o, info in zip(flat_f_outs, output_info) | ||
if info.output_type in [OutputType.non_alias, OutputType.unsafe_view_alias, OutputType.custom_function_view] | ||
and issubclass(info.raw_type, torch.Tensor) | ||
and info.requires_grad |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
I think you can equivalently check this with:
for o, info, requires_grad_info in zip(flat_f_outs, output_info, output_requires_grad_info):
...
and requires_grad_info
… being materialized" [ghstack-poisoned]
… being materialized" [ghstack-poisoned]
… being materialized" [ghstack-poisoned]
… being materialized" [ghstack-poisoned]
return flat_fn(*unpacked_args) | ||
|
||
if config.debug_assert: | ||
if config.debug_assert and False: |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
yeah these debug asserts are a pain, but I think they're pretty useful. The idea is that after each layer (removing duplicate inputs, replacing aliased inputs with synthetic bases), the metadata from our analysis pass is slightly different.
We could just re-run the analysis pass, but that would require multiple trips through the user forward with our tracing infra (maybe... fake tensor caching is now fast enough that we can just do this? idk). But instead, there are some helper functions that try to convert the metadata manually, and then these debug asserts are used to make sure that we actually got the metadata correct.
… being materialized" [ghstack-poisoned]
Closed in favor of #110721 |
Stack from ghstack (oldest at bottom):