Closed
Labels: actionable, module: optimizer, triaged
Description
🐛 Describe the bug
`torch.optim.Adam`: the default setting of the `step` state breaks when the global default dtype is float64.
```python
import torch
import torch.nn as nn
import torch.optim as optim

device = "cuda"
dtype = torch.float64
torch.set_default_dtype(torch.float64)  # global float64 default is what triggers the error

X = torch.randn(10, 1, device=device, dtype=dtype)
y = 2 * X + 1 + 0.1 * torch.randn(10, 1, device=device, dtype=dtype)


class LinearRegression(nn.Module):
    def __init__(self, device: str = None):
        super(LinearRegression, self).__init__()
        self.linear = nn.Linear(1, 1, device=device)

    def forward(self, x):
        return self.linear(x)


model = LinearRegression(device=device)
criterion = nn.MSELoss()
optimizer = optim.Adam(model.parameters(), lr=0.01)

# Training loop: the first optimizer.step() raises the RuntimeError below.
for epoch in range(10):
    outputs = model(X)
    loss = criterion(outputs, y)
    optimizer.zero_grad()
    loss.backward()
    optimizer.step()
```
```
Traceback (most recent call last):
  File "/home/zbwu/Desktop/example/test-1.py", line 32, in <module>
    optimizer.step()
  File "/home/zbwu/soft/anaconda3/lib/python3.10/site-packages/torch/optim/optimizer.py", line 373, in wrapper
    out = func(*args, **kwargs)
  File "/home/zbwu/soft/anaconda3/lib/python3.10/site-packages/torch/optim/optimizer.py", line 76, in _use_grad
    ret = func(self, *args, **kwargs)
  File "/home/zbwu/soft/anaconda3/lib/python3.10/site-packages/torch/optim/adam.py", line 165, in step
    adam(
  File "/home/zbwu/soft/anaconda3/lib/python3.10/site-packages/torch/optim/adam.py", line 313, in adam
    func(params,
  File "/home/zbwu/soft/anaconda3/lib/python3.10/site-packages/torch/optim/adam.py", line 476, in _multi_tensor_adam
    grouped_tensors = Optimizer._group_tensors_by_device_and_dtype(
  File "/home/zbwu/soft/anaconda3/lib/python3.10/site-packages/torch/optim/optimizer.py", line 397, in _group_tensors_by_device_and_dtype
    return _group_tensors_by_device_and_dtype(tensorlistlist, with_indices)
  File "/home/zbwu/soft/anaconda3/lib/python3.10/site-packages/torch/utils/_contextlib.py", line 115, in decorate_context
    return func(*args, **kwargs)
  File "/home/zbwu/soft/anaconda3/lib/python3.10/site-packages/torch/utils/_foreach_utils.py", line 47, in _group_tensors_by_device_and_dtype
    torch._C._group_tensors_by_device_and_dtype(tensorlistlist, with_indices).items()
RuntimeError: Tensors of the same index must be on the same device and the same dtype except `step` tensors that can be CPU and float32 notwithstanding
```
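With `torch.set_default_dtype(torch.float64)`, Adam creates its `step` state as a float64 tensor. The multi-tensor (`_multi_tensor_adam`) path in the traceback only exempts CPU float32 `step` tensors from the same-device/same-dtype check, so it raises the error above.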
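The rule quoted in the RuntimeError can be modeled in a few lines. The sketch below is a simplified, hypothetical model of that rule as worded in the message, not PyTorch's actual implementation (which sits in C++ behind `torch._C._group_tensors_by_device_and_dtype`); `TensorInfo` and `check_same_index` are illustrative names. It shows why a float64 `step` is rejected while a CPU float32 one is accepted.

```python
from dataclasses import dataclass
from typing import Sequence


@dataclass(frozen=True)
class TensorInfo:
    """Hypothetical stand-in for a tensor: only its device and dtype matter here."""
    device: str  # e.g. "cuda:0" or "cpu"
    dtype: str   # e.g. "float64" or "float32"


def check_same_index(tensors: Sequence[TensorInfo], step: TensorInfo) -> None:
    """Simplified model of the check described in the RuntimeError message.

    All non-step tensors at one index (param, grad, optimizer state) must share
    a device and dtype; the `step` tensor must either match them as well or be
    a CPU float32 tensor. This is an illustration, not PyTorch's code.
    """
    ref = tensors[0]
    for t in tensors[1:]:
        if (t.device, t.dtype) != (ref.device, ref.dtype):
            raise RuntimeError("tensors of the same index must share device and dtype")
    matches = (step.device, step.dtype) == (ref.device, ref.dtype)
    cpu_float32 = step.device == "cpu" and step.dtype == "float32"
    if not (matches or cpu_float32):
        raise RuntimeError(
            f"step tensor on {step.device} with dtype {step.dtype} is not allowed"
        )


# Parameter, gradient, and both Adam moment buffers on the GPU in float64.
cuda64 = TensorInfo("cuda:0", "float64")
group = [cuda64, cuda64, cuda64, cuda64]

# Accepted: a CPU float32 step falls under the exemption.
check_same_index(group, TensorInfo("cpu", "float32"))

# Rejected: a float64 step (assumed here to live on the CPU, as in this report).
try:
    check_same_index(group, TensorInfo("cpu", "float64"))
except RuntimeError as err:
    print("rejected:", err)
```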
The error goes away if either the default `step` is changed to `torch.tensor(0., dtype=torch.float)` or the `torch.set_default_dtype(torch.float64)` call is removed.
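The second workaround can be sketched as follows: keep the global default dtype at float32 and request float64 explicitly for the data and the model. This is a minimal sketch under assumptions: it uses `nn.Linear` directly in place of the `LinearRegression` wrapper, and it falls back to the CPU when CUDA is unavailable (on the CPU the reported error may not reproduce in the first place, since the multi-tensor path is typically chosen by default only for CUDA tensors).

```python
import torch
import torch.nn as nn
import torch.optim as optim

# Assumption: fall back to the CPU when CUDA is unavailable (the report used "cuda").
device = "cuda" if torch.cuda.is_available() else "cpu"
dtype = torch.float64

# No torch.set_default_dtype(torch.float64) here: float64 is requested per tensor.
X = torch.randn(10, 1, device=device, dtype=dtype)
y = 2 * X + 1 + 0.1 * torch.randn(10, 1, device=device, dtype=dtype)

model = nn.Linear(1, 1, device=device, dtype=dtype)
criterion = nn.MSELoss()
optimizer = optim.Adam(model.parameters(), lr=0.01)

for epoch in range(10):
    loss = criterion(model(X), y)
    optimizer.zero_grad()
    loss.backward()
    optimizer.step()

# Parameters are float64; Adam's `step` counter follows the unchanged float32 default.
first_param = next(model.parameters())
print(first_param.dtype)                            # torch.float64
print(optimizer.state[first_param]["step"].dtype)   # torch.float32
```

Because the global default stays float32, the `step` counter is created as float32, which is the case the error message explicitly exempts.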
Versions
Collecting environment information...
PyTorch version: 2.1.0
Is debug build: False
CUDA used to build PyTorch: 11.8
ROCM used to build PyTorch: N/A
OS: Arch Linux (x86_64)
GCC version: (Arch Linux 10.5.0-1) 10.5.0
Clang version: 16.0.6
CMake version: version 3.27.7
Libc version: glibc-2.38
Python version: 3.10.13 (main, Sep 11 2023, 13:44:35) [GCC 11.2.0] (64-bit runtime)
Python platform: Linux-6.1.55-1-lts-x86_64-with-glibc2.38