Description
🐛 Describe the bug
Issue description
If a SobolEngine is initialized after torch.set_default_dtype(torch.float64), SobolEngine.draw ignores the dtype argument and instead returns samples with the default dtype.
Works as expected with torch.set_default_dtype(torch.float32)
>>> import torch
>>> from torch.quasirandom import SobolEngine
>>> torch.set_default_dtype(torch.float32)
>>> sobol_engine = SobolEngine(dimension=10, scramble=True, seed=0)
>>> sobol_engine.draw(n=5, dtype=torch.float32).dtype
torch.float32
>>> sobol_engine.draw(n=5, dtype=torch.float64).dtype
torch.float64

Still works as expected if we update the default dtype but keep the previous SobolEngine instance
>>> torch.set_default_dtype(torch.float64)
>>> sobol_engine.draw(n=5, dtype=torch.float32).dtype
torch.float32
>>> sobol_engine.draw(n=5, dtype=torch.float64).dtype
torch.float64

dtype is ignored if SobolEngine is initialized after torch.set_default_dtype(torch.float64)
>>> sobol_engine = SobolEngine(dimension=10, scramble=True, seed=0)
>>> sobol_engine.draw(n=5, dtype=torch.float32).dtype
torch.float64

Expected behavior
SobolEngine(...).draw(n=n, dtype=dtype) should always produce samples with the provided dtype.
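Until draw honors the argument in every case, one possible workaround is to cast the result explicitly. The sketch below is only an illustration: draw_as is a hypothetical helper name, not part of PyTorch, and it assumes an extra .to(dtype) call is acceptable for the caller.

```python
import torch
from torch.quasirandom import SobolEngine


def draw_as(engine: SobolEngine, n: int, dtype: torch.dtype) -> torch.Tensor:
    """Hypothetical helper: draw n points and enforce the requested dtype."""
    samples = engine.draw(n=n, dtype=dtype)
    # .to() returns the tensor unchanged when the dtype already matches,
    # so this only costs a cast when draw() returned a different dtype.
    return samples.to(dtype)


# Reproduce the setup from the report: the engine is created while the
# global default dtype is float64. The previous default is restored after.
previous_default = torch.get_default_dtype()
torch.set_default_dtype(torch.float64)
try:
    engine = SobolEngine(dimension=10, scramble=True, seed=0)
    samples = draw_as(engine, n=5, dtype=torch.float32)
finally:
    torch.set_default_dtype(previous_default)
```

With this wrapper, samples.dtype matches the requested dtype regardless of when the engine was constructed.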
Other proposed improvements
Currently, the dtype argument defaults to torch.float32.
class SobolEngine:
    ...
    def draw(self, n: int = 1, out: Optional[torch.Tensor] = None,
             dtype: torch.dtype = torch.float32) -> torch.Tensor:
        ...

We can update the default argument to None and produce samples with dtype=torch.get_default_dtype().
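As a rough illustration of that fallback, the resolution step could look like the sketch below. resolve_dtype is a hypothetical free function used only to show the idea; it is not an existing PyTorch API, and the actual change would live inside draw.

```python
from typing import Optional

import torch


def resolve_dtype(dtype: Optional[torch.dtype] = None) -> torch.dtype:
    """Hypothetical helper: use the global default when no dtype is given."""
    # An explicit dtype always wins; None defers to the current default,
    # which is read at call time rather than at engine construction.
    return torch.get_default_dtype() if dtype is None else dtype
```

The proposed signature change itself would be: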
class SobolEngine:
    ...
    def draw(self, n: int = 1, out: Optional[torch.Tensor] = None,
             dtype: Optional[torch.dtype] = None) -> torch.Tensor:
        ...

Versions
PyTorch version: 2.3.0
Is debug build: False
CUDA used to build PyTorch: None
ROCM used to build PyTorch: N/A
OS: macOS 14.4.1 (arm64)
GCC version: Could not collect
Clang version: 15.0.0 (clang-1500.3.9.4)
CMake version: Could not collect
Libc version: N/A
Python version: 3.10.14 (main, May 6 2024, 14:42:37) [Clang 14.0.6 ] (64-bit runtime)
Python platform: macOS-14.4.1-arm64-arm-64bit
Is CUDA available: False
CUDA runtime version: No CUDA
CUDA_MODULE_LOADING set to: N/A
GPU models and configuration: No CUDA
Nvidia driver version: No CUDA
cuDNN version: No CUDA
HIP runtime version: N/A
MIOpen runtime version: N/A
Is XNNPACK available: True
CPU:
Apple M1 Pro
Versions of relevant libraries:
[pip3] torch==2.3.0
[conda] pytorch 2.3.0 py3.10_0 pytorch