`SobolEngine.draw` does not respect the default `dtype` and always uses the passed in dtype (defaulted to float32) · Issue #126478 · pytorch/pytorch · GitHub

@saitcakmak

🐛 Describe the bug

Issue description

If the SobolEngine is initialized after torch.set_default_dtype(torch.float64), SobolEngine.draw ignores the dtype argument and instead returns samples in the default dtype.

This works as expected with torch.set_default_dtype(torch.float32):

>>> import torch
>>> from torch.quasirandom import SobolEngine
>>> torch.set_default_dtype(torch.float32)
>>> sobol_engine = SobolEngine(dimension=10, scramble=True, seed=0)
>>> sobol_engine.draw(n=5, dtype=torch.float32).dtype
torch.float32
>>> sobol_engine.draw(n=5, dtype=torch.float64).dtype
torch.float64

It still works as expected if we update the default dtype but keep the previous SobolEngine instance:

>>> torch.set_default_dtype(torch.float64)
>>> sobol_engine.draw(n=5, dtype=torch.float32).dtype
torch.float32
>>> sobol_engine.draw(n=5, dtype=torch.float64).dtype
torch.float64

However, dtype is ignored if the SobolEngine is initialized after torch.set_default_dtype(torch.float64):

>>> sobol_engine = SobolEngine(dimension=10, scramble=True, seed=0)
>>> sobol_engine.draw(n=5, dtype=torch.float32).dtype
torch.float64
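A plausible cause, offered as an assumption rather than a confirmed diagnosis: if SobolEngine computes and caches its first sample point at construction time by true-dividing an integer tensor (which yields a tensor in whatever the default dtype is at that moment), and the first draw then concatenates that cached point with the newly generated points, torch.cat's type promotion would upcast a float32 batch to float64. The sketch below reproduces only that mechanism with plain tensor ops; first_draw_dtype is a hypothetical helper, not SobolEngine code, and torch.rand stands in for the real Sobol points.

```python
import torch


def first_draw_dtype(default_dtype: torch.dtype, requested: torch.dtype) -> torch.dtype:
    """Hypothetical mimic of the suspected first-draw path (not SobolEngine's code)."""
    previous = torch.get_default_dtype()
    torch.set_default_dtype(default_dtype)
    try:
        # "Construction time": integer tensor / int is true division, so the
        # result takes the default floating dtype in effect right now.
        cached_first_point = (torch.zeros(10, dtype=torch.long) / 2**30).reshape(1, -1)
    finally:
        torch.set_default_dtype(previous)  # restore global state
    # "Draw time": the remaining points respect the requested dtype.
    rest = torch.rand(4, 10, dtype=requested)
    # torch.cat promotes mixed floating inputs to the wider dtype.
    return torch.cat((cached_first_point, rest), dim=-2).dtype


print(first_draw_dtype(torch.float32, torch.float32))  # torch.float32
print(first_draw_dtype(torch.float64, torch.float32))  # torch.float64, mirrors the report
```

If this is the mechanism, it would also explain why the second snippet above still behaves: that engine had already drawn before the default changed, so no cached point is concatenated.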

Expected behavior

SobolEngine(...).draw(n=n, dtype=dtype) should always produce samples with the provided dtype.
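Until draw respects dtype in this case, a caller-side workaround is to cast the output explicitly. This is a minimal sketch that uses only the SobolEngine constructor and draw call already shown above; draw_as is a hypothetical helper name. When the samples already have the requested dtype, .to(dtype) returns the same tensor without copying.

```python
import torch
from torch.quasirandom import SobolEngine


def draw_as(engine: SobolEngine, n: int, dtype: torch.dtype) -> torch.Tensor:
    """Hypothetical helper: draw n points and guarantee the returned dtype."""
    samples = engine.draw(n=n, dtype=dtype)
    # No-op if draw already respected dtype; otherwise an explicit cast.
    return samples.to(dtype)


# Reproduce the problematic setup: engine built while float64 is the default.
previous = torch.get_default_dtype()
torch.set_default_dtype(torch.float64)
try:
    engine = SobolEngine(dimension=10, scramble=True, seed=0)
    samples = draw_as(engine, n=5, dtype=torch.float32)
finally:
    torch.set_default_dtype(previous)

print(samples.dtype, tuple(samples.shape))  # torch.float32 (5, 10)
```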

Other proposed improvements

Currently, the dtype argument defaults to torch.float32.

from typing import Optional
import torch

class SobolEngine:
    ...
    def draw(self, n: int = 1, out: Optional[torch.Tensor] = None,
             dtype: torch.dtype = torch.float32) -> torch.Tensor:
        ...

We could change the default argument to None and, when no dtype is passed, produce samples with dtype=torch.get_default_dtype():

from typing import Optional
import torch

class SobolEngine:
    ...
    def draw(self, n: int = 1, out: Optional[torch.Tensor] = None,
             dtype: Optional[torch.dtype] = None) -> torch.Tensor:
        ...
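A None default helps because Python evaluates a default argument such as torch.float32 once, when the function is defined, whereas a None sentinel lets draw look up torch.get_default_dtype() at call time. The sketch below shows that resolution step plus one possible way (an assumption, not a merged fix) to keep the first draw consistent: cast any cached tensor to the resolved dtype before concatenating. resolve_draw_dtype and concat_first_draw are hypothetical names.

```python
from typing import Optional

import torch


def resolve_draw_dtype(dtype: Optional[torch.dtype] = None) -> torch.dtype:
    """Hypothetical: fall back to the default dtype at call time."""
    return torch.get_default_dtype() if dtype is None else dtype


def concat_first_draw(first_point: torch.Tensor, rest: torch.Tensor,
                      dtype: Optional[torch.dtype] = None) -> torch.Tensor:
    """Hypothetical: join a cached first point with new points in one dtype."""
    dtype = resolve_draw_dtype(dtype)
    # Casting both inputs leaves torch.cat nothing to promote.
    return torch.cat((first_point.to(dtype), rest.to(dtype)), dim=-2)


cached = torch.zeros(1, 3, dtype=torch.float64)  # e.g. built under a float64 default
fresh = torch.rand(4, 3, dtype=torch.float32)
print(concat_first_draw(cached, fresh, dtype=torch.float32).dtype)  # torch.float32
print(concat_first_draw(cached, fresh).dtype)  # follows torch.get_default_dtype()
```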

Versions

PyTorch version: 2.3.0
Is debug build: False
CUDA used to build PyTorch: None
ROCM used to build PyTorch: N/A

OS: macOS 14.4.1 (arm64)
GCC version: Could not collect
Clang version: 15.0.0 (clang-1500.3.9.4)
CMake version: Could not collect
Libc version: N/A

Python version: 3.10.14 (main, May 6 2024, 14:42:37) [Clang 14.0.6 ] (64-bit runtime)
Python platform: macOS-14.4.1-arm64-arm-64bit
Is CUDA available: False
CUDA runtime version: No CUDA
CUDA_MODULE_LOADING set to: N/A
GPU models and configuration: No CUDA
Nvidia driver version: No CUDA
cuDNN version: No CUDA
HIP runtime version: N/A
MIOpen runtime version: N/A
Is XNNPACK available: True

CPU:
Apple M1 Pro

Versions of relevant libraries:
[pip3] torch==2.3.0
[conda] pytorch 2.3.0 py3.10_0 pytorch

cc @pbelevich @albanD

Labels: module: python frontend, module: random, triaged
