Fix prod double backward when there are 2+ zeros by guilhermeleobas · Pull Request #113969 · pytorch/pytorch

Conversation

@guilhermeleobas
Collaborator

@guilhermeleobas guilhermeleobas commented Nov 17, 2023

@pytorch-bot

pytorch-bot bot commented Nov 17, 2023

🔗 Helpful Links

🧪 See artifacts and rendered test results at hud.pytorch.org/pr/113969

Note: Links to docs will display an error until the docs builds have been completed.

✅ No Failures

As of commit fe78741 with merge base 0d6d97d:
💚 Looks good so far! There are no failures yet. 💚

This comment was automatically generated by Dr. CI and updates every 15 minutes.

guilhermeleobas added a commit that referenced this pull request Nov 17, 2023
ghstack-source-id: 8ee3db3
Pull Request resolved: #113969
@guilhermeleobas guilhermeleobas marked this pull request as ready for review November 18, 2023 00:18
Collaborator

@albanD albanD left a comment


Any details?
Also this is most likely going to be a challenging change just for perf reasons.

@guilhermeleobas
Collaborator Author

@albanD, motivation for this one is #106789

Collaborator

@albanD albanD left a comment


Ho ok!
You can update the condition so this path is only taken when double backward is not needed (see the condition below).
You can also add a small comment to that effect: when setting up for double backward, we must do the hard work of properly computing the result, even though we know it will be all 0s, to ensure the autograd graph is properly created.

Tensor zero_idx = (input == 0).nonzero();
if (zero_idx.sym_numel() == 0) {
  return grad * (result / input).conj();
} else if (zero_idx.size(0) > 1) {
Collaborator


(!at::GradMode::is_enabled() && zero_idx.size(0) > 1)

Collaborator Author


Thanks for the tip, @albanD. I've updated the code.
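
For reference, with the suggestion applied the check reads roughly like this (a sketch combining the excerpt above with the suggested guard; the zero-safe fallback branch is paraphrased and may differ slightly from the actual code in FunctionsManual.cpp):

// Only take the "all zeros" fast path when no autograd graph is being recorded,
// so double backward still gets a properly constructed graph.
Tensor zero_idx = (input == 0).nonzero();
if (zero_idx.sym_numel() == 0) {
  // No zeros: the usual prod gradient, grad * prod(input) / input.
  return grad * (result / input).conj();
} else if (!at::GradMode::is_enabled() && zero_idx.size(0) > 1) {
  // 2+ zeros and no double backward requested: the gradient is identically zero.
  return at::zeros_like(input);
} else {
  // A single zero, or double backward requested: compute the gradient the slow,
  // zero-safe way so the autograd graph is properly created.
  return prod_safe_zeros_backward(grad, input.contiguous().view(-1), 0).view_as(input);
}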

guilhermeleobas added a commit that referenced this pull request Nov 20, 2023
ghstack-source-id: 68b2b58
Pull Request resolved: #113969
Collaborator

@albanD albanD left a comment


Thanks!
I would expect that there is some test that used to fail that is now passing?

@guilhermeleobas
Collaborator Author

@albanD, yes! The following fails on PyTorch main but works on this branch:

import torch
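# repro: double backward with two zeros in the input (the case this PR addresses)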

x = torch.tensor([2., 3, 0, 0], requires_grad=True)
y = torch.cumprod(x, dim=0)
gx, = torch.autograd.grad(y.sum(), x, create_graph=True)
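# this second-order grad call fails on current main but succeeds with this change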
gy = torch.autograd.grad(gx.sum(), x)
print(gy)

@albanD
Collaborator

albanD commented Nov 20, 2023

Ho I meant there should be a test in CI running this.

@guilhermeleobas
Collaborator Author

guilhermeleobas commented Nov 20, 2023

The extra input I added in common_methods_invocations should do it, no?

Edit: I'll post in a bit the test that would fail for this input without the changes.

python test/test_ops_gradients.py -k test_fn_gradgrad_prod_cpu_float64

Collaborator

@albanD albanD left a comment


The extra input I added in common_methods_invocations should do it, no?

That was my question :p So we already have a gradgradcheck for this input, and adding it without the C++ update would fail?
If so, then SGTM!

@soulitzer soulitzer removed their request for review November 20, 2023 22:34
@guilhermeleobas
Collaborator Author

The extra input I added in common_methods_invocations should do it, no?

That was my question :p So we already have a gradgradcheck for this input, and adding it without the C++ update would fail? If so, then SGTM!

Yep! python test/test_ops_gradients.py -k test_fn_gradgrad_prod_cpu_float64

Thanks for the review, @albanD. I'll merge this one.

@guilhermeleobas
Collaborator Author

@pytorchbot merge

@pytorch-bot pytorch-bot bot added the ciflow/trunk Trigger trunk jobs on your pull request label Nov 20, 2023
@pytorchmergebot
Collaborator

Merge failed

Reason: This PR needs a release notes: label
If your changes are user facing and intended to be a part of release notes, please use a label starting with release notes:.

If not, please add the topic: not user facing label.

To add a label, you can comment to pytorchbot, for example
@pytorchbot label "topic: not user facing"

For more information, see
https://github.com/pytorch/pytorch/wiki/PyTorch-AutoLabel-Bot#why-categorize-for-release-notes-and-how-does-it-work.


@guilhermeleobas guilhermeleobas added and then removed the module: autograd (Related to torch.autograd, and the autograd engine in general) label Nov 20, 2023
@guilhermeleobas guilhermeleobas added the release notes: autograd (release notes category) label Nov 20, 2023
@guilhermeleobas
Collaborator Author

@pytorchbot merge

@pytorchmergebot
Collaborator

Merge started

Your change will be merged once all checks pass (ETA 0-4 Hours).

Learn more about merging in the wiki.

Questions? Feedback? Please reach out to the PyTorch DevX Team

Advanced Debugging: check the merge workflow status here.

@facebook-github-bot facebook-github-bot deleted the gh/guilhermeleobas/13/head branch November 24, 2023 15:27