Fix LBFGS warning convert a tensor with requires_grad=True to a scalar by zeshengzong · Pull Request #160389 · pytorch/pytorch · GitHub

Conversation

@zeshengzong
Contributor

@zeshengzong zeshengzong commented Aug 12, 2025

Fixes #160197

Test Result

In [1]: import warnings
   ...: warnings.simplefilter('error')
   ...: import torch
   ...: print(torch.__version__)
   ...: a, b = torch.rand((2, 32, 32))
   ...: a.requires_grad_()
   ...: optimizer = torch.optim.LBFGS([a])
   ...: loss_fn = lambda x, y: (x-y).pow(2).mean()
   ...: 
   ...: def closure():
   ...:     optimizer.zero_grad()
   ...:     loss = loss_fn(a, b)
   ...:     loss.backward()
   ...:     return loss
   ...: 
   ...: for i in range(100):
   ...:     optimizer.step(closure)
   ...:     print(i, loss_fn(a, b))
   ...: 
2.9.0a0+gitf33f3f8
0 tensor(5.8066e-11, grad_fn=<MeanBackward0>)
1 tensor(5.8066e-11, grad_fn=<MeanBackward0>)
2 tensor(5.8066e-11, grad_fn=<MeanBackward0>)
3 tensor(5.8066e-11, grad_fn=<MeanBackward0>)
4 tensor(5.8066e-11, grad_fn=<MeanBackward0>)
5 tensor(5.8066e-11, grad_fn=<MeanBackward0>)
6 tensor(5.8066e-11, grad_fn=<MeanBackward0>)
7 tensor(5.8066e-11, grad_fn=<MeanBackward0>)
8 tensor(5.8066e-11, grad_fn=<MeanBackward0>)
9 tensor(5.8066e-11, grad_fn=<MeanBackward0>)
10 tensor(5.8066e-11, grad_fn=<MeanBackward0>)

...
pytest test/test_optim.py -vv

...
test/test_optim.py::TestOptimRenewedCUDA::test_tensor_lr_num_dim_2_NAdam_cuda_float32 PASSED [2.7192s] [ 99%]
test/test_optim.py::TestOptimRenewedCUDA::test_tensor_lr_num_dim_2_RAdam_cuda_float32 PASSED [2.5370s] [ 99%]
test/test_optim.py::TestOptimRenewedCUDA::test_tensor_lr_num_dim_2_RMSprop_cuda_float32 PASSED [2.0190s] [ 99%]
test/test_optim.py::TestOptimRenewedCUDA::test_tensor_lr_num_dim_2_Rprop_cuda_float32 PASSED [1.8554s] [ 99%]
test/test_optim.py::TestOptimRenewedCUDA::test_tensor_lr_num_dim_2_SGD_cuda_float32 PASSED [2.0433s] [ 99%]
test/test_optim.py::TestOptimRenewedCUDA::test_tensor_lr_num_dim_2_SparseAdam_cuda_float32 PASSED [1.1788s] [100%]

================== 1471 passed, 242 skipped in 2440.52s (0:40:40) ============
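
The manual check above turns every warning into an error, so it stops at the first one. A gentler variant, sketched here as an assumption and not taken from this PR or from test/test_optim.py, records warnings during a few LBFGS steps and filters for the scalar-conversion message. The helper name `lbfgs_scalar_warnings` and the substring it matches on are illustrative choices.

```python
import warnings

import torch


def lbfgs_scalar_warnings(steps=3):
    """Run a few LBFGS steps and return any scalar-conversion warning messages.

    Hypothetical helper for illustration only; it is not part of PyTorch.
    """
    torch.manual_seed(0)
    a = torch.rand(32, 32, requires_grad=True)
    b = torch.rand(32, 32)
    optimizer = torch.optim.LBFGS([a])

    def closure():
        optimizer.zero_grad()
        loss = (a - b).pow(2).mean()
        loss.backward()
        return loss  # still attached to the autograd graph

    # Record warnings instead of raising them, so every step runs.
    with warnings.catch_warnings(record=True) as caught:
        warnings.simplefilter("always")
        for _ in range(steps):
            optimizer.step(closure)

    # Substring taken from the UserWarning text quoted in this conversation.
    needle = "requires_grad=True to a scalar"
    return [str(w.message) for w in caught if needle in str(w.message)]


hits = lbfgs_scalar_warnings()
print(len(hits), "matching warning(s)")
```

Whether the list is empty depends on the installed build, so the sketch only reports the count and does not assert on it.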

@pytorch-bot

pytorch-bot bot commented Aug 12, 2025

🔗 Helpful Links

🧪 See artifacts and rendered test results at hud.pytorch.org/pr/160389

Note: Links to docs will display an error until the docs builds have been completed.

✅ No Failures

As of commit b9b291c with merge base 74280d0 (image):
💚 Looks good so far! There are no failures yet. 💚

This comment was automatically generated by Dr. CI and updates every 15 minutes.

@zeshengzong zeshengzong marked this pull request as ready for review August 13, 2025 09:00
Collaborator

@albanD albanD left a comment


This will be simpler :)

@cpuhrsch cpuhrsch added the triaged This issue has been looked at a team member, and triaged and prioritized into an appropriate module label Aug 14, 2025
@janeyx99
Contributor

Just checking to make sure I know what's going on--is the before state that those tests would print warnings?

@zeshengzong
Contributor Author

Just checking to make sure I know what's going on--is the before state that those tests would print warnings?

I get the following result: it raises a UserWarning and stops executing:

In [1]: import warnings
   ...: warnings.simplefilter('error')
   ...: import torch
   ...: print(torch.__version__)
   ...: a, b = torch.rand((2, 32, 32))
   ...: a.requires_grad_()
   ...: optimizer = torch.optim.LBFGS([a])
   ...: loss_fn = lambda x, y: (x-y).pow(2).mean()
   ...: 
   ...: def closure():
   ...:     optimizer.zero_grad()
   ...:     loss = loss_fn(a, b)
   ...:     loss.backward()
   ...:     return loss
   ...: 
   ...: for i in range(100):
   ...:     optimizer.step(closure)
   ...:     print(i, loss_fn(a, b))
   ...: 
2.9.0a0+git23b0334
---------------------------------------------------------------------------
UserWarning                               Traceback (most recent call last)
Cell In[1], line 17
     14     return loss
     16 for i in range(100):
---> 17     optimizer.step(closure)
     18     print(i, loss_fn(a, b))

File ~/code/pytorch/torch/optim/optimizer.py:516, in Optimizer.profile_hook_step.<locals>.wrapper(*args, **kwargs)
    511         else:
    512             raise RuntimeError(
    513                 f"{func} must return None or a tuple of (new_args, new_kwargs), but got {result}."
    514             )
--> 516 out = func(*args, **kwargs)
    517 self._optimizer_step_code()
    519 # call optimizer step post hooks

File ~/code/pytorch/torch/utils/_contextlib.py:120, in context_decorator.<locals>.decorate_context(*args, **kwargs)
    117 @functools.wraps(func)
    118 def decorate_context(*args, **kwargs):
    119     with ctx_factory():
--> 120         return func(*args, **kwargs)

File ~/code/pytorch/torch/optim/lbfgs.py:462, in LBFGS.step(self, closure, zero_grad)
    457 if n_iter != max_iter:
    458     # re-evaluate function only if not in last iteration
    459     # the reason we do this: in a stochastic setting,
    460     # no use to re-evaluate that function here
    461     with torch.enable_grad():
--> 462         loss = float(closure())
    463     flat_grad = self._gather_flat_grad()
    464     opt_cond = flat_grad.abs().max() <= tolerance_grad

UserWarning: Converting a tensor with requires_grad=True to a scalar may lead to unexpected behavior.
Consider using tensor.detach() first. (Triggered internally at /home/coder/code/pytorch/torch/csrc/autograd/generated/python_variable_methods.cpp:836.)
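
The session stops with a traceback because `warnings.simplefilter('error')` promotes every warning to an exception. The following standard-library-only sketch shows that mechanism in isolation; `noisy_step` is a hypothetical stand-in for the library call that warns, and its message is made up for the example.

```python
import warnings


def noisy_step():
    # Hypothetical stand-in for a library call that emits a UserWarning.
    warnings.warn("scalar conversion may lead to unexpected behavior", UserWarning)
    return 1.0


def run_with_filter(action):
    """Call noisy_step() under the given warnings filter action."""
    with warnings.catch_warnings():
        warnings.simplefilter(action)
        try:
            return ("returned", noisy_step())
        except UserWarning as exc:
            # Under the "error" action the warning is raised as an exception.
            return ("raised", str(exc))


print(run_with_filter("ignore"))
print(run_with_filter("error"))
```

Under the default filters the same call would only print the warning and continue, which is why the issue is easy to overlook outside a strict setup like the one above.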

Returning loss.item() from the closure makes it work:

In [2]: import warnings
   ...: warnings.simplefilter('error')
   ...: import torch
   ...: print(torch.__version__)
   ...: a, b = torch.rand((2, 32, 32))
   ...: a.requires_grad_()
   ...: optimizer = torch.optim.LBFGS([a])
   ...: loss_fn = lambda x, y: (x-y).pow(2).mean()
   ...: 
   ...: def closure():
   ...:     optimizer.zero_grad()
   ...:     loss = loss_fn(a, b)
   ...:     loss.backward()
   ...:     return loss.item()
   ...: 
   ...: for i in range(100):
   ...:     optimizer.step(closure)
   ...:     print(i, loss_fn(a, b))
   ...: 
2.9.0a0+git23b0334
0 tensor(5.6172e-11, grad_fn=<MeanBackward0>)
1 tensor(5.6172e-11, grad_fn=<MeanBackward0>)
2 tensor(5.6172e-11, grad_fn=<MeanBackward0>)
3 tensor(5.6172e-11, grad_fn=<MeanBackward0>)
4 tensor(5.6172e-11, grad_fn=<MeanBackward0>)
5 tensor(5.6172e-11, grad_fn=<MeanBackward0>)
6 tensor(5.6172e-11, grad_fn=<MeanBackward0>)
7 tensor(5.6172e-11, grad_fn=<MeanBackward0>)
...
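
The two sessions above differ only in what the closure returns. Another user-side option on an affected build, sketched here as an assumption and not as the change merged in this PR, is to return `loss.detach()` so the value handed back to the optimizer no longer requires grad. `backward()` has already run inside the closure by then, so the gradients are unaffected.

```python
import torch

torch.manual_seed(0)
a = torch.rand(32, 32, requires_grad=True)
b = torch.rand(32, 32)
optimizer = torch.optim.LBFGS([a])


def loss_fn(x, y):
    return (x - y).pow(2).mean()


def closure():
    optimizer.zero_grad()
    loss = loss_fn(a, b)
    loss.backward()
    # Detach only the returned value; a.grad was already populated above.
    return loss.detach()


for _ in range(3):
    optimizer.step(closure)

# Evaluate the final loss without building an autograd graph.
with torch.no_grad():
    final_loss = float(loss_fn(a, b))
print(final_loss)
```

Compared with `loss.item()`, returning a detached tensor keeps the return value a tensor, which may matter if the caller uses the value returned by `optimizer.step(closure)`.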

@janeyx99
Contributor

@pytorchbot merge

@pytorch-bot pytorch-bot bot added the ciflow/trunk Trigger trunk jobs on your pull request label Aug 22, 2025
@janeyx99
Contributor

@pytorchbot merge -r

@pytorchmergebot
Collaborator

@pytorchbot started a rebase job onto refs/remotes/origin/viable/strict. Check the current status here

@pytorchmergebot
Collaborator

Successfully rebased opt/optim/lbfgs onto refs/remotes/origin/viable/strict, please pull locally before adding more changes (for example, via git checkout opt/optim/lbfgs && git pull --rebase)

@pytorch-bot pytorch-bot bot removed the ciflow/trunk Trigger trunk jobs on your pull request label Aug 22, 2025
@pytorchmergebot
Collaborator

Merge started

Your change will be merged once all checks pass (ETA 0-4 Hours).

Learn more about merging in the wiki.

Questions? Feedback? Please reach out to the PyTorch DevX Team

Advanced Debugging
Check the merge workflow status
here

@pytorchmergebot
Collaborator

Merge failed

Reason: 3 mandatory check(s) failed. The first few are:

Dig deeper by viewing the failures on hud

Details for Dev Infra team Raised by workflow job

Failing merge rule: Core Maintainers

@janeyx99
Contributor

@pytorchbot merge

@pytorch-bot pytorch-bot bot added the ciflow/trunk Trigger trunk jobs on your pull request label Aug 22, 2025
@pytorchmergebot
Collaborator

Merge started

Your change will be merged once all checks pass (ETA 0-4 Hours).

@pytorchmergebot
Collaborator

Merge failed

Reason: 1 mandatory check(s) failed. The first few are:

Dig deeper by viewing the failures on hud

Details for Dev Infra team Raised by workflow job

Failing merge rule: Core Maintainers

@zeshengzong
Contributor Author

@pytorchbot rebase -b main

@pytorchmergebot
Collaborator

@pytorchbot started a rebase job onto refs/remotes/origin/main. Check the current status here

@pytorchmergebot
Collaborator

Successfully rebased opt/optim/lbfgs onto refs/remotes/origin/main, please pull locally before adding more changes (for example, via git checkout opt/optim/lbfgs && git pull --rebase)

@pytorch-bot pytorch-bot bot removed the ciflow/trunk Trigger trunk jobs on your pull request label Aug 25, 2025
@zeshengzong
Contributor Author

@pytorchbot merge

@pytorch-bot pytorch-bot bot added the ciflow/trunk Trigger trunk jobs on your pull request label Aug 25, 2025
@pytorchmergebot
Collaborator

Merge started

Your change will be merged once all checks pass (ETA 0-4 Hours).

@pytorchmergebot
Collaborator

Merge failed

Reason: 1 jobs have failed, first few of them are: trunk / macos-py3-arm64 / test (mps, 1, 1, macos-m2-15)

Details for Dev Infra team Raised by workflow job

@zeshengzong
Contributor Author

@pytorchbot merge

@pytorchmergebot
Collaborator

Merge started

Your change will be merged once all checks pass (ETA 0-4 Hours).

markc-614 pushed a commit to markc-614/pytorch that referenced this pull request Sep 17, 2025
pytorch#160389)

Fixes pytorch#160197

Pull Request resolved: pytorch#160389
Approved by: https://github.com/janeyx99

Co-authored-by: albanD <desmaison.alban@gmail.com>

Labels

ciflow/trunk Trigger trunk jobs on your pull request Merged open source release notes: optim triaged This issue has been looked at a team member, and triaged and prioritized into an appropriate module

Development

Successfully merging this pull request may close these issues.

LBFGS always raises warning about converting a tensor with requires_grad=True to a scalar?

6 participants