Description
🐛 Describe the bug
The following code uses both torch.compile and triton.jit. It raises an error during compilation. From the generated code, I can see that constexpr is not qualified as tl.constexpr and that it also uses square brackets instead of parentheses. It seems that torch.compile generates the code using repr(), because I can work around this bug by changing tl.constexpr.__repr__:
def __repr__(self) -> str:
-    return f"constexpr[{self.value}]"
+    return f"tl.constexpr({self.value})"

This problem exists for PyTorch 2.4, 2.5, as well as the nightly build (2.6.0.dev20241021+cu121).
I saw there is an existing issue related to constexpr (#136504), but it does not seem to be the same problem as this one.
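For reference, the same workaround can be applied at runtime without editing the Triton sources. This is only a minimal monkey-patch sketch of what I did, not a proposed fix; the helper name _patched_constexpr_repr is just for illustration:

import triton.language as tl

def _patched_constexpr_repr(self) -> str:
    # Emit a repr that is valid Python when Inductor pastes it into the
    # generated module, instead of the default "constexpr[<value>]".
    return f"tl.constexpr({self.value})"

# Patch the class so module-level constants captured by torch.compile
# round-trip through repr() correctly.
tl.constexpr.__repr__ = _patched_constexpr_repr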
Code
import torch
import triton
import triton.language as tl


def _dtype_min_max(dtype: torch.dtype) -> tuple[tl.constexpr, tl.constexpr]:
    info = torch.finfo(dtype)
    return tl.constexpr(info.min), tl.constexpr(info.max)


_FLOAT8_MIN, _FLOAT8_MAX = _dtype_min_max(torch.float8_e4m3fn)
_FLOAT16_MIN, _FLOAT16_MAX = _dtype_min_max(torch.float16)
_BFLOAT16_MIN, _BFLOAT16_MAX = _dtype_min_max(torch.bfloat16)


@triton.jit
def scale_and_clamp(x, scale, dtype):
    if dtype == tl.float8e4nv:
        clamp_min = _FLOAT8_MIN
        clamp_max = _FLOAT8_MAX
    elif dtype == tl.float16:
        clamp_min = _FLOAT16_MIN
        clamp_max = _FLOAT16_MAX
    elif dtype == tl.bfloat16:
        clamp_min = _BFLOAT16_MIN
        clamp_max = _BFLOAT16_MAX
    else:
        tl.static_assert(False, f"Unsupported dtype: {dtype}")
    return tl.clamp(x * scale, clamp_min, clamp_max).to(dtype)


@triton.jit
def _scaled_cast_kernel(
    o_ptr,
    x_ptr,
    scale_ptr,
    d,
    BLOCK_SIZE: tl.constexpr,
):
    i = tl.program_id(axis=0)
    offsets = i * BLOCK_SIZE + tl.arange(0, BLOCK_SIZE)
    mask = offsets < d
    scale = tl.load(scale_ptr)
    x = tl.load(x_ptr + offsets, mask=mask).to(tl.float32)
    result = scale_and_clamp(x, scale, o_ptr.dtype.element_ty)
    tl.store(o_ptr + offsets, result, mask=mask)


def scaled_cast(
    x: torch.Tensor,
    dtype: torch.dtype,
    scale: torch.Tensor,
) -> torch.Tensor:
    """Scaled type conversion kernel."""
    d = x.numel()
    o = torch.empty((d,), dtype=dtype, device=x.device)

    def grid(meta):
        return (triton.cdiv(d, meta["BLOCK_SIZE"]),)

    _scaled_cast_kernel[grid](
        o_ptr=o,
        x_ptr=x,
        scale_ptr=scale,
        d=d,
        BLOCK_SIZE=1024,
    )
    return o.reshape(x.shape)


def test_scaled_cast() -> None:
    device = torch.device(0)
    x = torch.rand((8,), dtype=torch.float16, device=device)
    scale = torch.tensor(8, dtype=torch.float32, device=device)
    scaled_cast_fn = torch.compile(scaled_cast, fullgraph=True)
    result = scaled_cast_fn(x, torch.float8_e4m3fn, scale)
    torch.testing.assert_close(
        x.to(torch.float32),
        result.to(torch.float32) * scale.reciprocal(),
        atol=1e-1,
        rtol=1e-3,
    )


if __name__ == "__main__":
    test_scaled_cast()

Error trace
Traceback (most recent call last):
File "2024-10-21-torch-compile-triton-constexpr.py", line 93, in <module>
test_scaled_cast()
File "2024-10-21-torch-compile-triton-constexpr.py", line 83, in test_scaled_cast
result = scaled_cast_fn(x, torch.float8_e4m3fn, scale)
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "$HOME/miniforge3/envs/nightly/lib/python3.12/site-packages/torch/_dynamo/eval_frame.py", line 550, in _fn
return fn(*args, **kwargs)
^^^^^^^^^^^^^^^^^^^
File "$HOME/miniforge3/envs/nightly/lib/python3.12/site-packages/torch/_dynamo/convert_frame.py", line 1364, in __call__
return self._torchdynamo_orig_callable(
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "$HOME/miniforge3/envs/nightly/lib/python3.12/site-packages/torch/_dynamo/convert_frame.py", line 544, in __call__
return _compile(
^^^^^^^^^
File "$HOME/miniforge3/envs/nightly/lib/python3.12/site-packages/torch/_dynamo/convert_frame.py", line 964, in _compile
guarded_code = compile_inner(code, one_graph, hooks, transform)
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "$HOME/miniforge3/envs/nightly/lib/python3.12/site-packages/torch/_dynamo/convert_frame.py", line 695, in compile_inner
return _compile_inner(code, one_graph, hooks, transform)
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "$HOME/miniforge3/envs/nightly/lib/python3.12/site-packages/torch/_utils_internal.py", line 87, in wrapper_function
return function(*args, **kwargs)
^^^^^^^^^^^^^^^^^^^^^^^^^
File "$HOME/miniforge3/envs/nightly/lib/python3.12/site-packages/torch/_dynamo/convert_frame.py", line 728, in _compile_inner
out_code = transform_code_object(code, transform)
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "$HOME/miniforge3/envs/nightly/lib/python3.12/site-packages/torch/_dynamo/bytecode_transformation.py", line 1337, in transform_code_object
transformations(instructions, code_options)
File "$HOME/miniforge3/envs/nightly/lib/python3.12/site-packages/torch/_dynamo/convert_frame.py", line 229, in _fn
return fn(*args, **kwargs)
^^^^^^^^^^^^^^^^^^^
File "$HOME/miniforge3/envs/nightly/lib/python3.12/site-packages/torch/_dynamo/convert_frame.py", line 657, in transform
tracer.run()
File "$HOME/miniforge3/envs/nightly/lib/python3.12/site-packages/torch/_dynamo/symbolic_convert.py", line 2887, in run
super().run()
File "$HOME/miniforge3/envs/nightly/lib/python3.12/site-packages/torch/_dynamo/symbolic_convert.py", line 1095, in run
while self.step():
^^^^^^^^^^^
File "$HOME/miniforge3/envs/nightly/lib/python3.12/site-packages/torch/_dynamo/symbolic_convert.py", line 1007, in step
self.dispatch_table[inst.opcode](self, inst)
File "$HOME/miniforge3/envs/nightly/lib/python3.12/site-packages/torch/_dynamo/symbolic_convert.py", line 3078, in RETURN_VALUE
self._return(inst)
File "$HOME/miniforge3/envs/nightly/lib/python3.12/site-packages/torch/_dynamo/symbolic_convert.py", line 3063, in _return
self.output.compile_subgraph(
File "$HOME/miniforge3/envs/nightly/lib/python3.12/site-packages/torch/_dynamo/output_graph.py", line 1131, in compile_subgraph
self.compile_and_call_fx_graph(tx, list(reversed(stack_values)), root)
File "$HOME/miniforge3/envs/nightly/lib/python3.12/site-packages/torch/_dynamo/output_graph.py", line 1400, in compile_and_call_fx_graph
compiled_fn = self.call_user_compiler(gm)
^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "$HOME/miniforge3/envs/nightly/lib/python3.12/site-packages/torch/_dynamo/output_graph.py", line 1447, in call_user_compiler
return self._call_user_compiler(gm)
^^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "$HOME/miniforge3/envs/nightly/lib/python3.12/site-packages/torch/_dynamo/output_graph.py", line 1496, in _call_user_compiler
raise BackendCompilerFailed(self.compiler_fn, e).with_traceback(
File "$HOME/miniforge3/envs/nightly/lib/python3.12/site-packages/torch/_dynamo/output_graph.py", line 1477, in _call_user_compiler
compiled_fn = compiler_fn(gm, self.example_inputs())
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "$HOME/miniforge3/envs/nightly/lib/python3.12/site-packages/torch/_dynamo/repro/after_dynamo.py", line 130, in __call__
compiled_gm = compiler_fn(gm, example_inputs)
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "$HOME/miniforge3/envs/nightly/lib/python3.12/site-packages/torch/_dynamo/repro/after_dynamo.py", line 130, in __call__
compiled_gm = compiler_fn(gm, example_inputs)
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "$HOME/miniforge3/envs/nightly/lib/python3.12/site-packages/torch/__init__.py", line 2249, in __call__
return compile_fx(model_, inputs_, config_patches=self.config)
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "$HOME/miniforge3/envs/nightly/lib/python3.12/site-packages/torch/_inductor/compile_fx.py", line 1605, in compile_fx
return aot_autograd(
^^^^^^^^^^^^^
File "$HOME/miniforge3/envs/nightly/lib/python3.12/site-packages/torch/_dynamo/backends/common.py", line 72, in __call__
cg = aot_module_simplified(gm, example_inputs, **self.kwargs)
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "$HOME/miniforge3/envs/nightly/lib/python3.12/site-packages/torch/_functorch/aot_autograd.py", line 1087, in aot_module_simplified
compiled_fn = dispatch_and_compile()
^^^^^^^^^^^^^^^^^^^^^^
File "$HOME/miniforge3/envs/nightly/lib/python3.12/site-packages/torch/_functorch/aot_autograd.py", line 1063, in dispatch_and_compile
compiled_fn, _ = create_aot_dispatcher_function(
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "$HOME/miniforge3/envs/nightly/lib/python3.12/site-packages/torch/_functorch/aot_autograd.py", line 524, in create_aot_dispatcher_function
return _create_aot_dispatcher_function(
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "$HOME/miniforge3/envs/nightly/lib/python3.12/site-packages/torch/_functorch/aot_autograd.py", line 762, in _create_aot_dispatcher_function
compiled_fn, fw_metadata = compiler_fn(
^^^^^^^^^^^^
File "$HOME/miniforge3/envs/nightly/lib/python3.12/site-packages/torch/_functorch/_aot_autograd/jit_compile_runtime_wrappers.py", line 196, in aot_dispatch_base
compiled_fw = compiler(fw_module, updated_flat_args)
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "$HOME/miniforge3/envs/nightly/lib/python3.12/site-packages/torch/_inductor/compile_fx.py", line 1428, in fw_compiler_base
return _fw_compiler_base(model, example_inputs, is_inference)
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "$HOME/miniforge3/envs/nightly/lib/python3.12/site-packages/torch/_inductor/compile_fx.py", line 1499, in _fw_compiler_base
return inner_compile(
^^^^^^^^^^^^^^
File "$HOME/miniforge3/envs/nightly/lib/python3.12/site-packages/torch/_inductor/compile_fx.py", line 474, in compile_fx_inner
return wrap_compiler_debug(_compile_fx_inner, compiler_name="inductor")(
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "$HOME/miniforge3/envs/nightly/lib/python3.12/site-packages/torch/_dynamo/repro/after_aot.py", line 85, in debug_wrapper
inner_compiled_fn = compiler_fn(gm, example_inputs)
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "$HOME/miniforge3/envs/nightly/lib/python3.12/site-packages/torch/_inductor/compile_fx.py", line 658, in _compile_fx_inner
compiled_graph = FxGraphCache.load(
^^^^^^^^^^^^^^^^^^
File "$HOME/miniforge3/envs/nightly/lib/python3.12/site-packages/torch/_inductor/codecache.py", line 1462, in load
compiled_graph = compile_fx_fn(
^^^^^^^^^^^^^^
File "$HOME/miniforge3/envs/nightly/lib/python3.12/site-packages/torch/_inductor/compile_fx.py", line 569, in codegen_and_compile
compiled_graph = fx_codegen_and_compile(gm, example_inputs, **fx_kwargs)
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "$HOME/miniforge3/envs/nightly/lib/python3.12/site-packages/torch/_inductor/compile_fx.py", line 879, in fx_codegen_and_compile
compiled_fn = graph.compile_to_fn()
^^^^^^^^^^^^^^^^^^^^^
File "$HOME/miniforge3/envs/nightly/lib/python3.12/site-packages/torch/_inductor/graph.py", line 2013, in compile_to_fn
return self.compile_to_module().call
^^^^^^^^^^^^^^^^^^^^^^^^
File "$HOME/miniforge3/envs/nightly/lib/python3.12/site-packages/torch/_inductor/graph.py", line 1935, in compile_to_module
return self._compile_to_module()
^^^^^^^^^^^^^^^^^^^^^^^^^
File "$HOME/miniforge3/envs/nightly/lib/python3.12/site-packages/torch/_inductor/graph.py", line 1967, in _compile_to_module
mod = PyCodeCache.load_by_key_path(
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "$HOME/miniforge3/envs/nightly/lib/python3.12/site-packages/torch/_inductor/codecache.py", line 3005, in load_by_key_path
mod = _reload_python_module(key, path)
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "$HOME/miniforge3/envs/nightly/lib/python3.12/site-packages/torch/_inductor/runtime/compile_tasks.py", line 45, in _reload_python_module
exec(code, mod.__dict__, mod.__dict__)
File "/tmp/torchinductor_lequn/xu/cxukmmra2p6rp6v4w5wwwcv5mhenjyvxa5eebzzdqk4lfrogsd5e.py", line 34, in <module>
_scaled_cast_kernel_0 = async_compile.triton('_scaled_cast_kernel', '''
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "$HOME/miniforge3/envs/nightly/lib/python3.12/site-packages/torch/_inductor/async_compile.py", line 208, in triton
kernel = TritonCodeCache.load(kernel_name, source_code)
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "$HOME/miniforge3/envs/nightly/lib/python3.12/site-packages/torch/_inductor/codecache.py", line 3055, in load
return _module_to_triton_kernel(PyCodeCache.load(source_code), kernel_name)
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "$HOME/miniforge3/envs/nightly/lib/python3.12/site-packages/torch/_inductor/codecache.py", line 2992, in load
return cls.load_by_key_path(key, path, linemap, attrs)
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "$HOME/miniforge3/envs/nightly/lib/python3.12/site-packages/torch/_inductor/codecache.py", line 3005, in load_by_key_path
mod = _reload_python_module(key, path)
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "$HOME/miniforge3/envs/nightly/lib/python3.12/site-packages/torch/_inductor/runtime/compile_tasks.py", line 45, in _reload_python_module
exec(code, mod.__dict__, mod.__dict__)
File "/tmp/torchinductor_lequn/do/cdozu2ltlyjdfoq54heh3oh3c6egepd4luyqnxx4b6bj4fxo5w2s.py", line 51, in <module>
_FLOAT8_MIN = constexpr[-448.0]
^^^^^^^^^
torch._dynamo.exc.BackendCompilerFailed: backend='inductor' raised:
NameError: name 'constexpr' is not defined
Set TORCH_LOGS="+dynamo" and TORCHDYNAMO_VERBOSE=1 for more information
You can suppress this exception and fall back to eager by setting:
import torch._dynamo
torch._dynamo.config.suppress_errors = True
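To make the failure mode concrete: the offending line in the generated module below is a bare name lookup, so importing that module fails before any kernel runs. A minimal sketch that reproduces just that step (the assignment string is copied verbatim from the generated code shown next):

# Executing the generated assignment in a fresh namespace raises the same
# NameError, because the bare name `constexpr` is never defined or imported
# in the generated module (only `tl` is).
namespace = {}
try:
    exec("_FLOAT8_MIN = constexpr[-448.0]", namespace)
except NameError as exc:
    print(exc)  # name 'constexpr' is not defined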
Generated code
import triton
import triton.language as tl
from triton.compiler.compiler import AttrsDescriptor
from torch._inductor.runtime import triton_helpers, triton_heuristics
from torch._inductor.runtime.triton_helpers import libdevice, math as tl_math
from torch._inductor.runtime.hints import AutotuneHint, ReductionHint, TileHint, DeviceProperties
@triton_heuristics.user_autotune(
    configs=[],
    inductor_meta={'kernel_name': '_scaled_cast_kernel_0', 'backend_hash': '369B60C25BDE3BDB2552679A82B7C6B5364D80AF90BAFCF99210C0F49B749E9C', 'are_deterministic_algorithms_enabled': False, 'assert_indirect_indexing': True, 'autotune_local_cache': True, 'autotune_pointwise': True, 'autotune_remote_cache': None, 'force_disable_caches': False, 'dynamic_scale_rblock': True, 'max_autotune': False, 'max_autotune_pointwise': False, 'min_split_scan_rblock': 256, 'spill_threshold': 16, 'store_cubin': False},
    triton_meta={'signature': {'o_ptr': '*fp8e4nv', 'x_ptr': '*fp16', 'scale_ptr': '*fp32', 'd': 'i32'}, 'device': DeviceProperties(type='cuda', index=0, cc=90, major=9, regs_per_multiprocessor=65536, max_threads_per_multi_processor=2048, multi_processor_count=132, warp_size=32), 'constants': {'BLOCK_SIZE': 1024}, 'configs': [AttrsDescriptor(divisible_by_16=(0, 1, 2), equal_to_1=())]},
    filename=__file__,
    custom_kernel=True,
)
@triton.jit
def _scaled_cast_kernel(
    o_ptr,
    x_ptr,
    scale_ptr,
    d,
    BLOCK_SIZE: tl.constexpr,
):
    i = tl.program_id(axis=0)
    offsets = i * BLOCK_SIZE + tl.arange(0, BLOCK_SIZE)
    mask = offsets < d
    scale = tl.load(scale_ptr)
    x = tl.load(x_ptr + offsets, mask=mask).to(tl.float32)
    result = scale_and_clamp(x, scale, o_ptr.dtype.element_ty)
    tl.store(o_ptr + offsets, result, mask=mask)

@triton.jit
def scale_and_clamp(x, scale, dtype):
    if dtype == tl.float8e4nv:
        clamp_min = _FLOAT8_MIN
        clamp_max = _FLOAT8_MAX
    elif dtype == tl.float16:
        clamp_min = _FLOAT16_MIN
        clamp_max = _FLOAT16_MAX
    elif dtype == tl.bfloat16:
        clamp_min = _BFLOAT16_MIN
        clamp_max = _BFLOAT16_MAX
    else:
        tl.static_assert(False, f"Unsupported dtype: {dtype}")
    return tl.clamp(x * scale, clamp_min, clamp_max).to(dtype)

_FLOAT8_MIN = constexpr[-448.0]
_FLOAT8_MAX = constexpr[448.0]
_FLOAT16_MIN = constexpr[-65504.0]
_FLOAT16_MAX = constexpr[65504.0]
_BFLOAT16_MIN = constexpr[-3.3895313892515355e+38]
_BFLOAT16_MAX = constexpr[3.3895313892515355e+38]

Versions
Collecting environment information...
PyTorch version: 2.6.0.dev20241021+cu121
Is debug build: False
CUDA used to build PyTorch: 12.1
ROCM used to build PyTorch: N/A
OS: Ubuntu 20.04.6 LTS (x86_64)
GCC version: (Ubuntu 9.4.0-1ubuntu1~20.04.2) 9.4.0
Clang version: Could not collect
CMake version: version 3.29.3
Libc version: glibc-2.31
Python version: 3.12.7 | packaged by conda-forge | (main, Oct 4 2024, 16:05:46) [GCC 13.3.0] (64-bit runtime)
Python platform: Linux-5.15.0-1062-aws-x86_64-with-glibc2.31
Is CUDA available: True
CUDA runtime version: 12.1.105
CUDA_MODULE_LOADING set to: LAZY
GPU models and configuration:
GPU 0: NVIDIA H100 80GB HBM3
GPU 1: NVIDIA H100 80GB HBM3
GPU 2: NVIDIA H100 80GB HBM3
GPU 3: NVIDIA H100 80GB HBM3
GPU 4: NVIDIA H100 80GB HBM3
GPU 5: NVIDIA H100 80GB HBM3
GPU 6: NVIDIA H100 80GB HBM3
GPU 7: NVIDIA H100 80GB HBM3
Nvidia driver version: 535.183.01
cuDNN version: Could not collect
HIP runtime version: N/A
MIOpen runtime version: N/A
Is XNNPACK available: True
CPU:
Architecture: x86_64
CPU op-mode(s): 32-bit, 64-bit
Byte Order: Little Endian
Address sizes: 48 bits physical, 48 bits virtual
CPU(s): 96
On-line CPU(s) list: 0-95
Thread(s) per core: 1
Core(s) per socket: 48
Socket(s): 2
NUMA node(s): 2
Vendor ID: AuthenticAMD
CPU family: 25
Model: 1
Model name: AMD EPYC 7R13 Processor
Stepping: 1
CPU MHz: 1840.200
BogoMIPS: 5299.99
Hypervisor vendor: KVM
Virtualization type: full
L1d cache: 3 MiB
L1i cache: 3 MiB
L2 cache: 48 MiB
L3 cache: 384 MiB
NUMA node0 CPU(s): 0-47
NUMA node1 CPU(s): 48-95
Vulnerability Gather data sampling: Not affected
Vulnerability Itlb multihit: Not affected
Vulnerability L1tf: Not affected
Vulnerability Mds: Not affected
Vulnerability Meltdown: Not affected
Vulnerability Mmio stale data: Not affected
Vulnerability Retbleed: Not affected
Vulnerability Spec rstack overflow: Mitigation; safe RET, no microcode
Vulnerability Spec store bypass: Mitigation; Speculative Store Bypass disabled via prctl and seccomp
Vulnerability Spectre v1: Mitigation; usercopy/swapgs barriers and __user pointer sanitization
Vulnerability Spectre v2: Mitigation; Retpolines; IBPB conditional; IBRS_FW; STIBP always-on; RSB filling; PBRSB-eIBRS Not affected; BHI Not affected
Vulnerability Srbds: Not affected
Vulnerability Tsx async abort: Not affected
Flags: fpu vme de pse tsc msr pae mce cx8 apic sep mtrr pge mca cmov pat pse36 clflush mmx fxsr sse sse2 ht syscall nx mmxext fxsr_opt pdpe1gb rdtscp lm constant_tsc rep_good nopl nonstop_tsc cpuid extd_apicid aperfmperf tsc_known_freq pni pclmulqdq monitor ssse3 fma cx16 pcid sse4_1 sse4_2 x2apic movbe popcnt aes xsave avx f16c rdrand hypervisor lahf_lm cmp_legacy cr8_legacy abm sse4a misalignsse 3dnowprefetch topoext perfctr_core invpcid_single ssbd ibrs ibpb stibp vmmcall fsgsbase bmi1 avx2 smep bmi2 invpcid rdseed adx smap clflushopt clwb sha_ni xsaveopt xsavec xgetbv1 clzero xsaveerptr rdpru wbnoinvd arat npt nrip_save vaes vpclmulqdq rdpid
Versions of relevant libraries:
[pip3] numpy==2.1.2
[pip3] nvidia-cublas-cu12==12.1.3.1
[pip3] nvidia-cuda-cupti-cu12==12.1.105
[pip3] nvidia-cuda-nvrtc-cu12==12.1.105
[pip3] nvidia-cuda-runtime-cu12==12.1.105
[pip3] nvidia-cudnn-cu12==9.1.0.70
[pip3] nvidia-cufft-cu12==11.0.2.54
[pip3] nvidia-curand-cu12==10.3.2.106
[pip3] nvidia-cusolver-cu12==11.4.5.107
[pip3] nvidia-cusparse-cu12==12.1.0.106
[pip3] nvidia-cusparselt-cu12==0.6.2
[pip3] nvidia-nccl-cu12==2.21.5
[pip3] nvidia-nvjitlink-cu12==12.1.105
[pip3] nvidia-nvtx-cu12==12.1.105
[pip3] pytorch-triton==3.1.0+cf34004b8a
[pip3] torch==2.6.0.dev20241021+cu121
[conda] numpy 2.1.2 pypi_0 pypi
[conda] nvidia-cublas-cu12 12.1.3.1 pypi_0 pypi
[conda] nvidia-cuda-cupti-cu12 12.1.105 pypi_0 pypi
[conda] nvidia-cuda-nvrtc-cu12 12.1.105 pypi_0 pypi
[conda] nvidia-cuda-runtime-cu12 12.1.105 pypi_0 pypi
[conda] nvidia-cudnn-cu12 9.1.0.70 pypi_0 pypi
[conda] nvidia-cufft-cu12 11.0.2.54 pypi_0 pypi
[conda] nvidia-curand-cu12 10.3.2.106 pypi_0 pypi
[conda] nvidia-cusolver-cu12 11.4.5.107 pypi_0 pypi
[conda] nvidia-cusparse-cu12 12.1.0.106 pypi_0 pypi
[conda] nvidia-cusparselt-cu12 0.6.2 pypi_0 pypi
[conda] nvidia-nccl-cu12 2.21.5 pypi_0 pypi
[conda] nvidia-nvjitlink-cu12 12.1.105 pypi_0 pypi
[conda] nvidia-nvtx-cu12 12.1.105 pypi_0 pypi
[conda] pytorch-triton 3.1.0+cf34004b8a pypi_0 pypi
[conda] torch 2.6.0.dev20241021+cu121 pypi_0 pypi
cc @ezyang @chauhang @penguinwu @oulgen @aakhundov @davidberard98