partial attempt at stopping non-differentiable values from being materialized by Chillee · Pull Request #110592 · pytorch/pytorch · GitHub

Conversation

@pytorch-bot

pytorch-bot bot commented Oct 5, 2023

🔗 Helpful Links

🧪 See artifacts and rendered test results at hud.pytorch.org/pr/110592

Note: Links to docs will display an error until the docs builds have been completed.

❌ 20 New Failures, 2 Unrelated Failures

As of commit 613e1a3 with merge base 08c7dcd:

NEW FAILURES - The following jobs have failed:

FLAKY - The following jobs failed but were likely due to flakiness present on trunk:

This comment was automatically generated by Dr. CI and updates every 15 minutes.

Chillee added a commit that referenced this pull request Oct 5, 2023
@github-actions github-actions bot requested a review from ezyang October 5, 2023 08:52
# If it is a Tensor, what the dynamic dims are (otherwise is None)
dynamic_dims: Optional[Set[int]]
# requires_grad
requires_grad: bool
Contributor

nit: ViewAndMutationMeta.requires_grad_info already has this info, since it is of length (# mutated inputs + # user outputs).
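For reference, a minimal sketch of the layout being described, assuming requires_grad_info is ordered as mutated inputs followed by user outputs (the counts and values here are made up):

```python
# Made-up values, just to show the assumed ordering of requires_grad_info:
# mutated inputs first, then user forward outputs.
num_mutated_inputs = 2
requires_grad_info = [
    True,   # mutated input 0
    False,  # mutated input 1
    True,   # user output 0
    False,  # user output 1
]

mutated_input_requires_grad = requires_grad_info[:num_mutated_inputs]
fw_output_requires_grad = requires_grad_info[num_mutated_inputs:]
assert len(requires_grad_info) == num_mutated_inputs + len(fw_output_requires_grad)
```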

Collaborator Author

@Chillee Chillee Oct 5, 2023

Any reason you prefer to have a separate object rather than have it on here? My intuition was that the fewer lists we have with implicitly matching order, the better.

Contributor

Yeah, this is a reasonable question; let me give some context. There are a few other places where we need to distinguish which tensors require gradients:

(1) Affects both mutated inputs and fw outputs (code): At runtime, the compiled forward graph returns both any updated inputs and any user fw outs. For both of these groups of tensors, we need to know which ones do not require gradients, so we can mark them as non_differentiable (see the sketch below this list).

(2) Affects just mutated inputs (code): During tracing of the joint, if we have a mutated input that requires gradients, we need to clone() it in the forward, so that we can autograd.grad() w.r.t. the input pre-mutation (this is actually kind of sub-optimal since the clone() affects runtime perf and is just to appease the autograd engine, but this case should be rare).

(3) Affects just fw outputs (code): At runtime after our CompiledFunction.apply() finishes, we need to regenerate fw outs that alias inputs. We need to know if the output alias no longer requires grad, which implies that a detach() happened during tracing (so we need to reapply that detach at runtime).
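As a self-contained toy sketch of case (1) above (a minimal autograd.Function, not the actual AOTAutograd runtime wrapper), marking an output as non-differentiable looks roughly like this:

```python
import torch

class ScaleAndCount(torch.autograd.Function):
    # Toy stand-in for the compiled forward: returns one differentiable
    # output and one output that should never receive gradients.
    @staticmethod
    def forward(ctx, x):
        scaled = x * 2.0
        count = torch.full((), float(x.numel()))  # derived value, no grad wanted
        ctx.mark_non_differentiable(count)        # case (1): mark it up front
        return scaled, count

    @staticmethod
    def backward(ctx, grad_scaled, grad_count):
        # grad_count is ignored; `count` was marked non-differentiable.
        return grad_scaled * 2.0

x = torch.randn(3, requires_grad=True)
scaled, count = ScaleAndCount.apply(x)
print(scaled.requires_grad, count.requires_grad)  # True False
scaled.sum().backward()
print(x.grad)  # gradient of 2*x is all 2s
```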

If we do what you put above and move the requires_grad info directly onto OutputInfo, then we'd have to either:

(a) Also put requires_grad info on InputInfo objects - but then every piece of code in AOTAutograd that cares about requires_grad information on inputs would have to manually filter down to the InputInfo objects that correspond to mutated inputs.

(b) Add another piece of metadata that corresponds just to "mutated inputs", onto which we tack the requires_grad info.

Let me know what you think. (a) actually might make things easier to reason about (even if it's a bit more boilerplate), but it will be a pretty annoying refactor.
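For concreteness, a hypothetical sketch of what option (a) could look like; the field names and the helper here are illustrative, not the real AOTAutograd definitions:

```python
from dataclasses import dataclass
from typing import List

# Hypothetical field layout for option (a); not the actual AOTAutograd classes.
@dataclass
class InputInfo:
    mutates_data: bool
    requires_grad: bool  # (a): stored on every input, mutated or not

@dataclass
class OutputInfo:
    requires_grad: bool

@dataclass
class ViewAndMutationMeta:
    input_info: List[InputInfo]
    output_info: List[OutputInfo]

    def mutated_input_requires_grad(self) -> List[bool]:
        # The filtering that every caller caring only about mutated inputs
        # would now have to do explicitly.
        return [i.requires_grad for i in self.input_info if i.mutates_data]
```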

for o, info in zip(flat_f_outs, output_info)
if info.output_type in [OutputType.non_alias, OutputType.unsafe_view_alias, OutputType.custom_function_view]
and issubclass(info.raw_type, torch.Tensor)
and info.requires_grad
Contributor

I think you can equivalently check this with:

for o, info, requires_grad_info in zip(flat_f_outs, output_info, output_requires_grad_info)
...
and requires_grad_info
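A toy, self-contained illustration of that zip-based filter; all names and values below are stand-ins for the real structures:

```python
# Stand-in data just to show the filtering pattern; the real code zips
# flat_f_outs / output_info / output_requires_grad_info the same way.
flat_f_outs = ["out0", "out1", "out2"]
is_plain_tensor_out = [True, False, True]   # proxy for the output_type / raw_type checks
output_requires_grad_info = [True, True, False]

fw_outs_requiring_grad = [
    o
    for o, is_plain, requires_grad in zip(
        flat_f_outs, is_plain_tensor_out, output_requires_grad_info
    )
    if is_plain and requires_grad
]
print(fw_outs_requiring_grad)  # ['out0']
```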

Chillee added a commit that referenced this pull request Oct 6, 2023
  return flat_fn(*unpacked_args)

- if config.debug_assert:
+ if config.debug_assert and False:
Contributor

Yeah, these debug asserts are a pain, but I think they're pretty useful. The idea is that after each layer (removing duplicate inputs, replacing aliased inputs with synthetic bases), the metadata from our analysis pass is slightly different.

We could just re-run the analysis pass, but that would require multiple trips through the user forward with our tracing infra (maybe... fake tensor caching is now fast enough that we can just do this? idk). Instead, there are some helper functions that try to convert the metadata manually, and these debug asserts are used to make sure that we actually got the metadata correct.
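A hypothetical illustration of that flow, using stand-in helpers rather than the real AOTAutograd ones: the metadata is converted by hand after deduplication, and the debug assert cross-checks it against a freshly recomputed analysis.

```python
def run_analysis_pass(flat_args):
    # Stand-in for re-tracing the user forward to recompute per-input metadata.
    return [{"requires_grad": a["requires_grad"]} for a in flat_args]

def convert_meta_after_dedup(meta, kept_indices):
    # The cheap manual conversion: keep only entries for surviving inputs.
    return [meta[i] for i in kept_indices]

debug_assert = True

flat_args = [
    {"requires_grad": True},   # input 0
    {"requires_grad": True},   # input 1 (duplicate of input 0)
    {"requires_grad": False},  # input 2
]
meta = run_analysis_pass(flat_args)

kept_indices = [0, 2]  # deduplication drops input 1
deduped_args = [flat_args[i] for i in kept_indices]
converted_meta = convert_meta_after_dedup(meta, kept_indices)

if debug_assert:
    # The expensive alternative would be to always re-run the analysis;
    # here it only runs to validate the manual conversion.
    assert converted_meta == run_analysis_pass(deduped_args)
```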

Chillee added a commit that referenced this pull request Oct 6, 2023
@Chillee
Collaborator Author

Chillee commented Oct 6, 2023

Closed in favor of #110721

@Chillee Chillee closed this Oct 7, 2023
@facebook-github-bot facebook-github-bot deleted the gh/chillee/224/head branch November 6, 2023 15:25