stop non-differentiable values from being materialized in aotautograd by Chillee · Pull Request #110721 · pytorch/pytorch

Conversation

@Chillee (Collaborator) commented Oct 6, 2023

@pytorch-bot bot commented Oct 6, 2023

🔗 Helpful Links

🧪 See artifacts and rendered test results at hud.pytorch.org/pr/110721

Note: Links to docs will display an error until the docs builds have been completed.

✅ No Failures

As of commit a10ef1b with merge base e3bf500:
💚 Looks good so far! There are no failures yet. 💚

This comment was automatically generated by Dr. CI and updates every 15 minutes.

Chillee added a commit that referenced this pull request Oct 6, 2023
@Chillee (Collaborator, Author) commented Oct 7, 2023

Interestingly, this seems to have a significant impact on our benchmarks, I'm guessing primarily due to eliding a bunch of memory copies in the backward pass.

[benchmark results screenshot]

For an example model (levit_128), this is what the trace looks like prior to this change. This change results in the green area and second pink area disappearing.

[profiler trace screenshot]
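For context, a minimal sketch of the kind of pattern this affects (the function, names, and shapes here are illustrative, not taken from the benchmark suite): a compiled function that returns a non-differentiable (integer) output alongside a differentiable one, where the integer output should not get a tangent materialized for it in the backward.

```python
import torch

# Illustrative only: one differentiable output and one non-differentiable
# (integer) output. The point of this PR is that the integer output should
# not have a tangent materialized for it in the traced backward.
def f(x):
    y = x.sin()
    idx = y.argmax(dim=-1)  # integer tensor, not differentiable
    return y, idx

compiled_f = torch.compile(f)

x = torch.randn(8, 128, requires_grad=True)
y, idx = compiled_f(x)
y.sum().backward()  # only y contributes gradients; idx gets no tangent
```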

@ezyang (Contributor) commented Oct 7, 2023

This seems plausible, but you still have test errors.

@Chillee (Collaborator, Author) commented Oct 7, 2023

Where do you see test errors? 🤔

@ezyang removed their request for review October 9, 2023 16:28
@ezyang (Contributor) commented Oct 9, 2023

I'm letting Brian review this

@Chillee (Collaborator, Author) commented Oct 9, 2023

@pytorchbot merge

@pytorch-bot bot commented Oct 9, 2023

This PR needs to be approved by an authorized maintainer before merge.

inp_tangents_filtered = [
    x
    for x, info_idx in zip(inp_tangents, mutated_inp_indices)
    if input_info[info_idx].mutates_data and input_info[info_idx].requires_grad
]
A Contributor commented on this snippet:
Hmm, I think the bit of the filter on `input_info[info_idx].requires_grad` should be unnecessary now?

Previously, inp_tangents corresponded to "every forward input that had a (data or metadata) mutation". We had to filter this down to the actual inputs to our backward graph that correspond to user forward inputs, which are fw inputs that had a data mutation and require grad. (Actually - double checking, I think we're already correctly filtering out metadata mutations).

But now that you're filtering down to fw inputs that have a mutation and require grad even earlier (as part of constructing the joint fw/bw), we should only have to filter out inputs with metadata-only mutations in this check.

@Chillee (Collaborator, Author) replied:

For future discussion, this isn't true. In this case, inp_tangents represents all of the inputs to the backward pass, which will include values that correspond to tensors that we don't actually trace with (for example, outputs of the forward pass that are non-differentiable).

So we still need to filter them out in this place; the logic is pretty analogous to the existing filtering.
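As a rough illustration of what filtering out non-differentiable values looks like (a sketch with made-up names, not the actual AOTAutograd code):

```python
import torch

def filter_differentiable_tangents(fw_outs, out_tangents):
    # Sketch: only differentiable forward outputs keep a tangent input to
    # the backward graph; non-differentiable outputs (e.g. integer tensors,
    # or tensors that don't require grad) are skipped instead of having a
    # zero tangent materialized for them.
    return [
        tangent
        for out, tangent in zip(fw_outs, out_tangents)
        if isinstance(out, torch.Tensor)
        and out.requires_grad
        and (out.is_floating_point() or out.is_complex())
    ]
```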

@Chillee (Collaborator, Author) commented Oct 9, 2023

@pytorchbot merge

@pytorch-bot bot added the ciflow/trunk (Trigger trunk jobs on your pull request) label Oct 9, 2023
@pytorchmergebot (Collaborator) commented:
Merge started

Your change will be merged once all checks pass (ETA 0-4 Hours).

Learn more about merging in the wiki.

Questions? Feedback? Please reach out to the PyTorch DevX Team

Advanced Debugging: check the merge workflow status here.
