
torch 2.8 RC regression - part 3 #158376

@ydshieh

🐛 Describe the bug

To Reproduce

Running on an A10 GPU.

1. Install torch 2.8 RC:

python3 -m pip uninstall -y torch torchvision torchaudio && python3 -m pip install --no-cache-dir torch torchvision torchaudio --index-url https://download.pytorch.org/whl/test/cu126

2. Install transformers:

git clone https://github.com/huggingface/transformers.git && cd transformers && git fetch origin && git checkout 6017f5e8 && pip install -e .[torch,testing]

3. Run the test:

RUN_SLOW=1 python3 -m pytest -v tests/models/timesfm/test_modeling_timesfm.py::TimesFmModelTest::test_sdpa_can_compile_dynamic

4. Error log:

_________________________________________________________________________ TimesFmModelTest.test_sdpa_can_compile_dynamic __________________________________________________________________________

self = <tests.models.timesfm.test_modeling_timesfm.TimesFmModelTest testMethod=test_sdpa_can_compile_dynamic>

    @require_torch_sdpa
    @require_torch_accelerator
    @slow
    def test_sdpa_can_compile_dynamic(self):
        if not self.has_attentions:
            self.skipTest(reason="Model architecture does not support attentions")
    
        device_type, major, minor = get_device_properties()
        if device_type == "cuda" and major < 8:
            self.skipTest(reason="This test requires an NVIDIA GPU with compute capability >= 8.0")
        elif device_type == "rocm" and major < 9:
            self.skipTest(reason="This test requires an AMD GPU with compute capability >= 9.0")
        elif device_type not in ["cuda", "rocm", "xpu"]:
            self.skipTest(reason="This test requires a Nvidia or AMD GPU, or an Intel XPU")

        torch.compiler.reset()

        for model_class in self.all_model_classes:
            if not model_class._supports_sdpa:
                self.skipTest(f"{model_class.__name__} does not support SDPA")

            config, inputs_dict = self.model_tester.prepare_config_and_inputs_for_common()
            inputs_dict = self._prepare_for_class(inputs_dict, model_class)
            if config.model_type in ["dbrx"]:
                self.skipTest(
                    "DBRX (transformers==4.40) requires a modification to support dynamic shapes with compile."
                )
            if getattr(config, "cache_implementation", None) == "hybrid":
                self.skipTest(
                    "Cannot compile forward without an existing cache with Hybrid, as `torch._dynamo.mark_static_address` "
                    "is a forbidden call."
                )

            model = model_class(config)

            sub_models_supporting_sdpa = [
                module._supports_sdpa
                for name, module in model.named_modules()
                if isinstance(module, PreTrainedModel) and name != ""
            ]
            supports_sdpa_all_modules = (
                all(sub_models_supporting_sdpa) if len(sub_models_supporting_sdpa) > 0 else model._supports_sdpa
            )
            if not supports_sdpa_all_modules:
                self.skipTest(reason="This models' submodels does not support sdpa")

            with tempfile.TemporaryDirectory() as tmpdirname:
                model.save_pretrained(tmpdirname)
                model = model_class.from_pretrained(tmpdirname, torch_dtype=torch.float16, attn_implementation="sdpa")
                model.to(torch_device)

                # For PyTorch 2.1 - 2.3.0 set `dynamic=True`. In the future setting `dynamic=None` and using `torch._dynamo.mark_dynamic()`
                # on input tensors will be required. `mark_dynamic` currently raises inconsistent shape errors.
                model = torch.compile(model, dynamic=True)

                inputs_dict.pop("attention_mask", None)
                inputs_dict.pop("decoder_attention_mask", None)
                for name, inp in inputs_dict.items():
                    if isinstance(inp, torch.Tensor) and inp.dtype in [torch.float32, torch.float16]:
                        inputs_dict[name] = inp.to(torch.float16)

                # use no_grad to save some memory
                with torch.no_grad():
>                   _ = model(**inputs_dict)

tests/test_modeling_common.py:3907:
_ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _
/usr/local/lib/python3.10/dist-packages/torch/_dynamo/eval_frame.py:375: in __call__
    return super().__call__(*args, **kwargs)
/usr/local/lib/python3.10/dist-packages/torch/nn/modules/module.py:1773: in _wrapped_call_impl
    return self._call_impl(*args, **kwargs)
/usr/local/lib/python3.10/dist-packages/torch/nn/modules/module.py:1784: in _call_impl
    return forward_call(*args, **kwargs)
/usr/local/lib/python3.10/dist-packages/torch/_dynamo/eval_frame.py:749: in compile_wrapper
    raise e.remove_dynamo_frames() from None  # see TORCHDYNAMO_VERBOSE=1
/usr/local/lib/python3.10/dist-packages/torch/_dynamo/output_graph.py:1871: in _call_user_compiler
    raise BackendCompilerFailed(
/usr/local/lib/python3.10/dist-packages/torch/_dynamo/output_graph.py:1846: in _call_user_compiler
    compiled_fn = compiler_fn(gm, example_inputs)
/usr/local/lib/python3.10/dist-packages/torch/_dynamo/repro/after_dynamo.py:150: in __call__
    compiled_gm = compiler_fn(gm, example_inputs)
/usr/local/lib/python3.10/dist-packages/torch/__init__.py:2381: in __call__
    return compile_fx(model_, inputs_, config_patches=self.config)
/usr/local/lib/python3.10/dist-packages/torch/_inductor/compile_fx.py:2418: in compile_fx
    return aot_autograd(
/usr/local/lib/python3.10/dist-packages/torch/_dynamo/backends/common.py:109: in __call__
    cg = aot_module_simplified(gm, example_inputs, **self.kwargs)
/usr/local/lib/python3.10/dist-packages/torch/_functorch/aot_autograd.py:1199: in aot_module_simplified
    compiled_fn = AOTAutogradCache.load(
/usr/local/lib/python3.10/dist-packages/torch/_functorch/_aot_autograd/autograd_cache.py:1140: in load
    compiled_fn = dispatch_and_compile()
/usr/local/lib/python3.10/dist-packages/torch/_functorch/aot_autograd.py:1184: in dispatch_and_compile
    compiled_fn, _ = create_aot_dispatcher_function(
/usr/local/lib/python3.10/dist-packages/torch/_functorch/aot_autograd.py:576: in create_aot_dispatcher_function
    return _create_aot_dispatcher_function(
/usr/local/lib/python3.10/dist-packages/torch/_functorch/aot_autograd.py:836: in _create_aot_dispatcher_function
    compiled_fn, fw_metadata = compiler_fn(
/usr/local/lib/python3.10/dist-packages/torch/_functorch/_aot_autograd/jit_compile_runtime_wrappers.py:245: in aot_dispatch_base
    compiled_fw = compiler(fw_module, updated_flat_args)
/usr/local/lib/python3.10/dist-packages/torch/_functorch/aot_autograd.py:483: in __call__
    return self.compiler_fn(gm, example_inputs)
/usr/local/lib/python3.10/dist-packages/torch/_inductor/compile_fx.py:2183: in fw_compiler_base
    _recursive_joint_graph_passes(gm)
/usr/local/lib/python3.10/dist-packages/torch/_inductor/compile_fx.py:492: in _recursive_joint_graph_passes
    joint_graph_passes(gm)
/usr/local/lib/python3.10/dist-packages/torch/_inductor/fx_passes/joint_graph.py:587: in joint_graph_passes
    GraphTransformObserver(graph, "constant_fold_uniform_value").apply_gm_pass(
/usr/local/lib/python3.10/dist-packages/torch/fx/passes/graph_transform_observer.py:78: in apply_gm_pass
    return pass_fn(self.gm)
/usr/local/lib/python3.10/dist-packages/torch/_inductor/fx_passes/joint_graph.py:384: in constant_fold_uniform_value
    cf.run()
/usr/local/lib/python3.10/dist-packages/torch/_inductor/constant_folding.py:295: in run
    return super().run(initial_env=env)
/usr/local/lib/python3.10/dist-packages/torch/fx/interpreter.py:173: in run
    self.env[node] = self.run_node(node)
/usr/local/lib/python3.10/dist-packages/torch/_inductor/constant_folding.py:250: in run_node
    out = self._deduce_value(node)
/usr/local/lib/python3.10/dist-packages/torch/_inductor/fx_passes/joint_graph.py:371: in _deduce_value
    return node.target(*args, **kwargs)
_ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _

self = <OpOverload(op='aten.scalar_tensor', overload='default')>, args = (1123581321.0,), kwargs = {'device': device(type='cpu'), 'dtype': torch.float16, 'pin_memory': False}

    def __call__(self, /, *args: _P.args, **kwargs: _P.kwargs) -> _T:
>       return self._op(*args, **kwargs)
E       torch._dynamo.exc.BackendCompilerFailed: backend='inductor' raised:
E       RuntimeError: value cannot be converted to type at::Half without overflow
E       
E       While executing %scalar_tensor_1 : [num_users=1] = call_function[target=torch.ops.aten.scalar_tensor.default](args = (1123581321.0,), kwargs = {dtype: torch.float16, device: cpu, pin_memory: False})
E       GraphModule: class <lambda>(torch.nn.Module):
E           def forward(self, arg0_1: "Sym(s5)", arg1_1: "f16[3, s5][s5, 1]", arg2_1: "i64[3][1]", arg3_1: "f64[][]", arg4_1: "f64[][]", arg5_1: "f16[32, 64][64, 1]", arg6_1: "f16[32][1]", arg7_1: "f16[16, 32][32, 1]", arg8_1: "f16[16][1]", arg9_1: "f16[16, 64][64, 1]", arg10_1: "f16[16][1]", arg11_1: "f32[8][1]", arg12_1: "f16[3, 16][16, 1]", arg13_1: "f64[][]", arg14_1: "f16[16][1]", arg15_1: "f16[16, 16][16, 1]", arg16_1: "f16[16][1]", arg17_1: "f16[8][1]", arg18_1: "f16[16, 16][16, 1]", arg19_1: "f16[16][1]", arg20_1: "f16[16, 16][16, 1]", arg21_1: "f16[16][1]", arg22_1: "f16[16, 16][16, 1]", arg23_1: "f16[16][1]", arg24_1: "f16[16][1]", arg25_1: "f16[16][1]", arg26_1: "f16[32, 16][16, 1]", arg27_1: "f16[32][1]", arg28_1: "f16[16, 32][32, 1]", arg29_1: "f16[16][1]", arg30_1: "f16[32, 16][16, 1]", arg31_1: "f16[32][1]", arg32_1: "f16[1280, 32][32, 1]", arg33_1: "f16[1280][1]", arg34_1: "f16[1280, 16][16, 1]", arg35_1: "f16[1280][1]"):
E                # File: /transformers/src/transformers/models/timesfm/modeling_timesfm.py:742 in forward, code: inputs = [ts[-fcontext_len:] for ts in past_values]
E               select: "f16[s5][1]" = torch.ops.aten.select.int(arg1_1, 0, 0)
E               select_1: "f16[s5][1]" = torch.ops.aten.select.int(arg1_1, 0, 1)
E               select_2: "f16[s5][1]" = torch.ops.aten.select.int(arg1_1, 0, 2);  arg1_1 = None
E
E                # File: /transformers/src/transformers/models/timesfm/modeling_timesfm.py:742 in <listcomp>, code: inputs = [ts[-fcontext_len:] for ts in past_values]
E               slice_1: "f16[s5][1]" = torch.ops.aten.slice.Tensor(select, 0, -512, 9223372036854775807);  select = None
E               slice_2: "f16[s5][1]" = torch.ops.aten.slice.Tensor(select_1, 0, -512, 9223372036854775807);  select_1 = None
E               slice_3: "f16[s5][1]" = torch.ops.aten.slice.Tensor(select_2, 0, -512, 9223372036854775807);  select_2 = None
E
E                # File: /transformers/src/transformers/models/timesfm/modeling_timesfm.py:743 in <listcomp>, code: inp_min = torch.min(torch.stack([torch.min(ts) for ts in inputs]))
E               min_1: "f16[][]" = torch.ops.aten.min.default(slice_1)
E               min_2: "f16[][]" = torch.ops.aten.min.default(slice_2)
E               min_3: "f16[][]" = torch.ops.aten.min.default(slice_3)
E
E                # File: /transformers/src/transformers/models/timesfm/modeling_timesfm.py:743 in forward, code: inp_min = torch.min(torch.stack([torch.min(ts) for ts in inputs]))
E               unsqueeze: "f16[1][1]" = torch.ops.aten.unsqueeze.default(min_1, 0);  min_1 = None
E               unsqueeze_1: "f16[1][1]" = torch.ops.aten.unsqueeze.default(min_2, 0);  min_2 = None
E               unsqueeze_2: "f16[1][1]" = torch.ops.aten.unsqueeze.default(min_3, 0);  min_3 = None
E               cat: "f16[3][1]" = torch.ops.aten.cat.default([unsqueeze, unsqueeze_1, unsqueeze_2]);  unsqueeze = unsqueeze_1 = unsqueeze_2 = None
E               min_4: "f16[][]" = torch.ops.aten.min.default(cat);  cat = None
E
E                # File: /transformers/src/transformers/models/timesfm/modeling_timesfm.py:639 in _preprocess, code: padding = torch.zeros(input_len + self.horizon_len, dtype=ts.dtype, device=ts.device)
E               add_12: "Sym(s5 + 128)" = arg0_1 + 128
E               full: "f16[s5 + 128][1]" = torch.ops.aten.full.default([add_12], 0, dtype = torch.float16, layout = torch.strided, device = device(type='cuda', index=0), pin_memory = False)      
E
E                # File: /transformers/src/transformers/models/timesfm/modeling_timesfm.py:641 in _preprocess, code: num_front_pad = self.context_len - input_len
E               sub_7: "Sym(512 - s5)" = 512 - arg0_1;  arg0_1 = None
E
E                # File: /transformers/src/transformers/models/timesfm/modeling_timesfm.py:642 in _preprocess, code: ts = torch.cat([torch.zeros(num_front_pad, dtype=ts.dtype, device=ts.device), ts], dim=0)
E               full_1: "f16[512 - s5][1]" = torch.ops.aten.full.default([sub_7], 0, dtype = torch.float16, layout = torch.strided, device = device(type='cuda', index=0), pin_memory = False)     
E               cat_1: "f16[512][1]" = torch.ops.aten.cat.default([full_1, slice_1]);  full_1 = slice_1 = None
E
E                # File: /transformers/src/transformers/models/timesfm/modeling_timesfm.py:643 in _preprocess, code: padding = torch.cat([torch.ones(num_front_pad, dtype=ts.dtype, device=padding.device), padding], dim=0)
E               full_2: "f16[512 - s5][1]" = torch.ops.aten.full.default([sub_7], 1, dtype = torch.float16, layout = torch.strided, device = device(type='cuda', index=0), pin_memory = False)     
E               cat_2: "f16[640][1]" = torch.ops.aten.cat.default([full_2, full]);  full_2 = full = None
E
E                # File: /transformers/src/transformers/models/timesfm/modeling_timesfm.py:650 in _preprocess, code: inp_freq.append(freq[i])
E               select_3: "i64[][]" = torch.ops.aten.select.int(arg2_1, 0, 0)
E
E                # File: /transformers/src/transformers/models/timesfm/modeling_timesfm.py:639 in _preprocess, code: padding = torch.zeros(input_len + self.horizon_len, dtype=ts.dtype, device=ts.device)
E               full_3: "f16[s5 + 128][1]" = torch.ops.aten.full.default([add_12], 0, dtype = torch.float16, layout = torch.strided, device = device(type='cuda', index=0), pin_memory = False)    
E
E                # File: /transformers/src/transformers/models/timesfm/modeling_timesfm.py:642 in _preprocess, code: ts = torch.cat([torch.zeros(num_front_pad, dtype=ts.dtype, device=ts.device), ts], dim=0)
E               full_4: "f16[512 - s5][1]" = torch.ops.aten.full.default([sub_7], 0, dtype = torch.float16, layout = torch.strided, device = device(type='cuda', index=0), pin_memory = False)     
E               cat_3: "f16[512][1]" = torch.ops.aten.cat.default([full_4, slice_2]);  full_4 = slice_2 = None
E
E                # File: /transformers/src/transformers/models/timesfm/modeling_timesfm.py:643 in _preprocess, code: padding = torch.cat([torch.ones(num_front_pad, dtype=ts.dtype, device=padding.device), padding], dim=0)
E               full_5: "f16[512 - s5][1]" = torch.ops.aten.full.default([sub_7], 1, dtype = torch.float16, layout = torch.strided, device = device(type='cuda', index=0), pin_memory = False)     
E               cat_4: "f16[640][1]" = torch.ops.aten.cat.default([full_5, full_3]);  full_5 = full_3 = None
E
E                # File: /transformers/src/transformers/models/timesfm/modeling_timesfm.py:650 in _preprocess, code: inp_freq.append(freq[i])
E               select_4: "i64[][]" = torch.ops.aten.select.int(arg2_1, 0, 1)
E
E                # File: /transformers/src/transformers/models/timesfm/modeling_timesfm.py:639 in _preprocess, code: padding = torch.zeros(input_len + self.horizon_len, dtype=ts.dtype, device=ts.device)
E               full_6: "f16[s5 + 128][1]" = torch.ops.aten.full.default([add_12], 0, dtype = torch.float16, layout = torch.strided, device = device(type='cuda', index=0), pin_memory = False);  add_12 = None
E
E                # File: /transformers/src/transformers/models/timesfm/modeling_timesfm.py:642 in _preprocess, code: ts = torch.cat([torch.zeros(num_front_pad, dtype=ts.dtype, device=ts.device), ts], dim=0)
E               full_7: "f16[512 - s5][1]" = torch.ops.aten.full.default([sub_7], 0, dtype = torch.float16, layout = torch.strided, device = device(type='cuda', index=0), pin_memory = False)     
E               cat_5: "f16[512][1]" = torch.ops.aten.cat.default([full_7, slice_3]);  full_7 = slice_3 = None
E
E                # File: /transformers/src/transformers/models/timesfm/modeling_timesfm.py:643 in _preprocess, code: padding = torch.cat([torch.ones(num_front_pad, dtype=ts.dtype, device=padding.device), padding], dim=0)
E               full_8: "f16[512 - s5][1]" = torch.ops.aten.full.default([sub_7], 1, dtype = torch.float16, layout = torch.strided, device = device(type='cuda', index=0), pin_memory = False);  sub_7 = None
E               cat_6: "f16[640][1]" = torch.ops.aten.cat.default([full_8, full_6]);  full_8 = full_6 = None
E
E                # File: /transformers/src/transformers/models/timesfm/modeling_timesfm.py:650 in _preprocess, code: inp_freq.append(freq[i])
E               select_5: "i64[][]" = torch.ops.aten.select.int(arg2_1, 0, 2);  arg2_1 = None
E
E                # File: /transformers/src/transformers/models/timesfm/modeling_timesfm.py:653 in _preprocess, code: torch.stack(input_ts, dim=0),
E               cat_7: "f16[1536][1]" = torch.ops.aten.cat.default([cat_1, cat_3, cat_5]);  cat_1 = cat_3 = cat_5 = None
E               view: "f16[3, 512][512, 1]" = torch.ops.aten.view.default(cat_7, [3, 512]);  cat_7 = None
E
E                # File: /transformers/src/transformers/models/timesfm/modeling_timesfm.py:654 in _preprocess, code: torch.stack(input_padding, dim=0),
E               cat_8: "f16[1920][1]" = torch.ops.aten.cat.default([cat_2, cat_4, cat_6]);  cat_2 = cat_4 = cat_6 = None
E               view_1: "f16[3, 640][640, 1]" = torch.ops.aten.view.default(cat_8, [3, 640]);  cat_8 = None
E
E                # File: /transformers/src/transformers/models/timesfm/modeling_timesfm.py:655 in _preprocess, code: torch.tensor(inp_freq, dtype=torch.int32).reshape(-1, 1),
E               convert_element_type: "i32[][]" = torch.ops.prims.convert_element_type.default(select_3, torch.int32);  select_3 = None
E               device_put: "i32[][]" = torch.ops.prims.device_put.default(convert_element_type, device(type='cpu'));  convert_element_type = None
E               convert_element_type_1: "i32[][]" = torch.ops.prims.convert_element_type.default(select_4, torch.int32);  select_4 = None
E               device_put_1: "i32[][]" = torch.ops.prims.device_put.default(convert_element_type_1, device(type='cpu'));  convert_element_type_1 = None
E               convert_element_type_2: "i32[][]" = torch.ops.prims.convert_element_type.default(select_5, torch.int32);  select_5 = None
E               device_put_2: "i32[][]" = torch.ops.prims.device_put.default(convert_element_type_2, device(type='cpu'));  convert_element_type_2 = None
E               unsqueeze_3: "i32[1][1]" = torch.ops.aten.unsqueeze.default(device_put, 0);  device_put = None
E               unsqueeze_4: "i32[1][1]" = torch.ops.aten.unsqueeze.default(device_put_1, 0);  device_put_1 = None
E               unsqueeze_5: "i32[1][1]" = torch.ops.aten.unsqueeze.default(device_put_2, 0);  device_put_2 = None
E               cat_9: "i32[3][1]" = torch.ops.aten.cat.default([unsqueeze_3, unsqueeze_4, unsqueeze_5]);  unsqueeze_3 = unsqueeze_4 = unsqueeze_5 = None
E               view_5: "i32[3, 1][1, 1]" = torch.ops.aten.view.default(cat_9, [-1, 1]);  cat_9 = None
E
E                # File: /transformers/src/transformers/models/timesfm/modeling_timesfm.py:769 in forward, code: inp_freq = inp_freq.to(device)
E               device_put_3: "i32[3, 1][1, 1]" = torch.ops.prims.device_put.default(view_5, device(type='cuda', index=0));  view_5 = None
E
E                # File: /transformers/src/transformers/models/timesfm/modeling_timesfm.py:784 in forward, code: current_padding = input_padding[:, 0 : final_out.shape[1]]
E               slice_5: "f16[3, 512][640, 1]" = torch.ops.aten.slice.Tensor(view_1, 1, 0, 512);  view_1 = None
E
E                # File: /transformers/src/transformers/models/timesfm/modeling_timesfm.py:785 in forward, code: input_ts = final_out[:, -fcontext_len:]
E               slice_7: "f16[3, 512][512, 1]" = torch.ops.aten.slice.Tensor(view, 1, -512, 9223372036854775807);  view = None
E
E                # File: /transformers/src/transformers/models/timesfm/modeling_timesfm.py:786 in forward, code: input_padding = current_padding[:, -fcontext_len:]
E               slice_9: "f16[3, 512][640, 1]" = torch.ops.aten.slice.Tensor(slice_5, 1, -512, 9223372036854775807);  slice_5 = None
E
E                # File: /transformers/src/transformers/models/timesfm/modeling_timesfm.py:390 in forward, code: patched_inputs = past_values.view(bsize, -1, self.config.patch_length)
E               view_6: "f16[3, 16, 32][512, 32, 1]" = torch.ops.aten.view.default(slice_7, [3, -1, 32]);  slice_7 = None
E
E                # File: /transformers/src/transformers/models/timesfm/modeling_timesfm.py:391 in forward, code: patched_pads = past_values_padding.view(bsize, -1, self.config.patch_length)      
E               view_7: "f16[3, 16, 32][640, 32, 1]" = torch.ops.aten.view.default(slice_9, [3, -1, 32]);  slice_9 = None
E
E                # File: /transformers/src/transformers/models/timesfm/modeling_timesfm.py:394 in forward, code: torch.abs(patched_pads - 1.0) < self.config.tolerance,
E               sub_16: "f16[3, 16, 32][512, 32, 1]" = torch.ops.aten.sub.Tensor(view_7, 1.0)
E               abs_1: "f16[3, 16, 32][512, 32, 1]" = torch.ops.aten.abs.default(sub_16);  sub_16 = None
E
E               # No stacktrace found for following nodes
E               lt_tensor: "b8[3, 16, 32][512, 32, 1]" = torch.ops.aten.lt.Tensor(abs_1, arg3_1);  abs_1 = None
E
E                # File: /transformers/src/transformers/models/timesfm/modeling_timesfm.py:395 in forward, code: torch.tensor(0.0, dtype=patched_inputs.dtype, device=patched_inputs.device),      
E               _tensor_constant0: "f16[][]" = self._tensor_constant0
E               lift_fresh_copy: "f16[][]" = torch.ops.aten.lift_fresh_copy.default(_tensor_constant0);  _tensor_constant0 = None
E
E                # File: /transformers/src/transformers/models/timesfm/modeling_timesfm.py:393 in forward, code: patched_inputs = torch.where(
E               where: "f16[3, 16, 32][512, 32, 1]" = torch.ops.aten.where.self(lt_tensor, lift_fresh_copy, view_6);  lt_tensor = lift_fresh_copy = view_6 = None
E
E               # No stacktrace found for following nodes
E               convert_element_type_default: "f32[][]" = torch.ops.prims.convert_element_type.default(arg4_1, torch.float32)
E               sub_tensor: "f16[3, 16, 32][512, 32, 1]" = torch.ops.aten.sub.Tensor(where, convert_element_type_default);  convert_element_type_default = None
E
E                # File: /transformers/src/transformers/models/timesfm/modeling_timesfm.py:399 in forward, code: torch.abs(patched_inputs - self.config.pad_val) < self.config.tolerance,
E               abs_2: "f16[3, 16, 32][512, 32, 1]" = torch.ops.aten.abs.default(sub_tensor);  sub_tensor = None
E
E               # No stacktrace found for following nodes
E               lt_tensor_1: "b8[3, 16, 32][512, 32, 1]" = torch.ops.aten.lt.Tensor(abs_2, arg3_1);  abs_2 = None
E
E                # File: /transformers/src/transformers/models/timesfm/modeling_timesfm.py:400 in forward, code: torch.tensor(1.0, dtype=patched_pads.dtype, device=patched_pads.device),
E               _tensor_constant1: "f16[][]" = self._tensor_constant1
E               lift_fresh_copy_1: "f16[][]" = torch.ops.aten.lift_fresh_copy.default(_tensor_constant1);  _tensor_constant1 = None
E
E                # File: /transformers/src/transformers/models/timesfm/modeling_timesfm.py:398 in forward, code: patched_pads = torch.where(
E               where_1: "f16[3, 16, 32][512, 32, 1]" = torch.ops.aten.where.self(lt_tensor_1, lift_fresh_copy_1, view_7);  lt_tensor_1 = lift_fresh_copy_1 = view_7 = None
E
E                # File: /transformers/src/transformers/models/timesfm/modeling_timesfm.py:526 in _timesfm_masked_mean_std, code: pad_sum = torch.sum(1 - padding, dim=2)
E               sub_18: "f16[3, 16, 32][512, 32, 1]" = torch.ops.aten.sub.Tensor(1, where_1)
E               sum_1: "f16[3, 16][16, 1]" = torch.ops.aten.sum.dim_IntList(sub_18, [2]);  sub_18 = None
E
E                # File: /transformers/src/transformers/models/timesfm/modeling_timesfm.py:522 in _get_patch_index, code: indices = torch.argmax((arr >= 3).to(torch.int32), dim=1)
E               ge_13: "b8[3, 16][16, 1]" = torch.ops.aten.ge.Scalar(sum_1, 3)
E               convert_element_type_4: "i32[3, 16][16, 1]" = torch.ops.prims.convert_element_type.default(ge_13, torch.int32);  ge_13 = None
E               argmax: "i64[3][1]" = torch.ops.aten.argmax.default(convert_element_type_4, 1);  convert_element_type_4 = None
E
E                # File: /transformers/src/transformers/models/timesfm/modeling_timesfm.py:523 in _get_patch_index, code: row_sum = (arr >= 3).to(torch.int32).sum(dim=1)
E               ge_14: "b8[3, 16][16, 1]" = torch.ops.aten.ge.Scalar(sum_1, 3);  sum_1 = None
E               convert_element_type_5: "i32[3, 16][16, 1]" = torch.ops.prims.convert_element_type.default(ge_14, torch.int32);  ge_14 = None
E               sum_2: "i64[3][1]" = torch.ops.aten.sum.dim_IntList(convert_element_type_5, [1]);  convert_element_type_5 = None
E
E                # File: /transformers/src/transformers/models/timesfm/modeling_timesfm.py:524 in _get_patch_index, code: return torch.where(row_sum == 0, arr.shape[1] - 1, indices)
E               eq_39: "b8[3][1]" = torch.ops.aten.eq.Scalar(sum_2, 0);  sum_2 = None
E               scalar_tensor: "i64[][]" = torch.ops.aten.scalar_tensor.default(15, dtype = torch.int64, layout = torch.strided, device = device(type='cuda', index=0))
E               where_2: "i64[3][1]" = torch.ops.aten.where.self(eq_39, scalar_tensor, argmax);  eq_39 = scalar_tensor = argmax = None
E
E                # File: /transformers/src/transformers/models/timesfm/modeling_timesfm.py:528 in _timesfm_masked_mean_std, code: bidxs = torch.arange(inputs.shape[0])
E               iota: "i64[3][1]" = torch.ops.prims.iota.default(3, start = 0, step = 1, dtype = torch.int64, device = device(type='cpu'), requires_grad = False)
E
E                # File: /transformers/src/transformers/models/timesfm/modeling_timesfm.py:530 in _timesfm_masked_mean_std, code: arr = inputs[bidxs, patch_indices, :]
E               index: "f16[3, 32][32, 1]" = torch.ops.aten.index.Tensor(where, [iota, where_2])
E
E                # File: /transformers/src/transformers/models/timesfm/modeling_timesfm.py:531 in _timesfm_masked_mean_std, code: pad = padding[bidxs, patch_indices, :]
E               index_1: "f16[3, 32][32, 1]" = torch.ops.aten.index.Tensor(where_1, [iota, where_2]);  iota = where_2 = None
E
E                # File: /transformers/src/transformers/models/timesfm/modeling_timesfm.py:534 in _timesfm_masked_mean_std, code: mask = 1 - pad
E               sub_19: "f16[3, 32][32, 1]" = torch.ops.aten.sub.Tensor(1, index_1);  index_1 = None
E
E                # File: /transformers/src/transformers/models/timesfm/modeling_timesfm.py:537 in _timesfm_masked_mean_std, code: num_valid_elements = torch.sum(mask, dim=1)
E               sum_3: "f16[3][1]" = torch.ops.aten.sum.dim_IntList(sub_19, [1])
E
E                # File: /transformers/src/transformers/models/timesfm/modeling_timesfm.py:539 in _timesfm_masked_mean_std, code: num_valid_elements == 0,
E               eq_40: "b8[3][1]" = torch.ops.aten.eq.Scalar(sum_3, 0)
E
E                # File: /transformers/src/transformers/models/timesfm/modeling_timesfm.py:540 in _timesfm_masked_mean_std, code: torch.tensor(1, dtype=num_valid_elements.dtype, device=num_valid_elements.device),
E               _tensor_constant2: "f16[][]" = self._tensor_constant2
E               lift_fresh_copy_2: "f16[][]" = torch.ops.aten.lift_fresh_copy.default(_tensor_constant2);  _tensor_constant2 = None
E


(very long output omitted)


E               ge_16: "b8[][]" = torch.ops.aten.ge.Scalar(min_4, 0);  min_4 = None
E               return (ge_16, clone_4, select_9, add_39, div, where_5)
E
E       
E       Original traceback:
E         File "/transformers/src/transformers/models/timesfm/modeling_timesfm.py", line 787, in forward
E           decoder_output = self.decoder(
E         File "/transformers/src/transformers/utils/generic.py", line 961, in wrapper
E           output = func(self, *args, **kwargs)
E         File "/transformers/src/transformers/models/timesfm/modeling_timesfm.py", line 403, in forward
E           patched_inputs, stats = self._forward_transform(patched_inputs, patched_pads)
E         File "/transformers/src/transformers/models/timesfm/modeling_timesfm.py", line 365, in _forward_transform
E           torch.tensor(self.config.pad_val, dtype=outputs.dtype, device=outputs.device),
E       
E       
E       Set TORCHDYNAMO_VERBOSE=1 for the internal stack trace (please do this especially if you're reporting a bug to PyTorch). For even more developer context, set TORCH_LOGS="+dynamo"

/usr/local/lib/python3.10/dist-packages/torch/_ops.py:829: BackendCompilerFailed
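For reference, the failing op in the log is `aten.scalar_tensor(1123581321.0, dtype=torch.float16)`, emitted by inductor's `constant_fold_uniform_value` joint-graph pass; the constant matches the pad value the model materializes via `torch.tensor(self.config.pad_val, dtype=outputs.dtype, ...)` in the original traceback. Below is a minimal sketch of what appears to be the same pattern, independent of transformers (the literal is copied from the log, and the expected outputs are assumptions, not verified here):

```python
import torch

# float16's largest finite value is ~65504, so the pad value (1123581321.0) is
# far out of range. Eager tensor construction saturates to inf without raising,
# which is presumably why the model runs fine without torch.compile:
print(torch.tensor(1123581321.0, dtype=torch.float16))  # tensor(inf, dtype=torch.float16)

# The low-level op that the constant-folding pass calls performs a checked
# conversion instead and raises, matching the error above:
torch.ops.aten.scalar_tensor(1123581321.0, dtype=torch.float16)
# RuntimeError: value cannot be converted to type at::Half without overflow
```

If that holds, the overflow is introduced at compile time by constant folding rather than by the model itself, consistent with this being a 2.8 RC regression.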

Versions

Collecting environment information...
PyTorch version: 2.8.0+cu126
Is debug build: False
CUDA used to build PyTorch: 12.6
ROCM used to build PyTorch: N/A

OS: Ubuntu 22.04.4 LTS (x86_64)
GCC version: (Ubuntu 11.4.0-1ubuntu1~22.04) 11.4.0
Clang version: Could not collect
CMake version: Could not collect
Libc version: glibc-2.35

Python version: 3.10.12 (main, May 27 2025, 17:12:29) [GCC 11.4.0] (64-bit runtime)
Python platform: Linux-5.10.238-231.953.amzn2.x86_64-x86_64-with-glibc2.35
Is CUDA available: True
CUDA runtime version: Could not collect
CUDA_MODULE_LOADING set to: LAZY
GPU models and configuration: GPU 0: NVIDIA A10G
Nvidia driver version: 550.163.01
cuDNN version: Probably one of the following:
/usr/lib/x86_64-linux-gnu/libcudnn.so.9.3.0
/usr/lib/x86_64-linux-gnu/libcudnn_adv.so.9.3.0
/usr/lib/x86_64-linux-gnu/libcudnn_cnn.so.9.3.0
/usr/lib/x86_64-linux-gnu/libcudnn_engines_precompiled.so.9.3.0
/usr/lib/x86_64-linux-gnu/libcudnn_engines_runtime_compiled.so.9.3.0
/usr/lib/x86_64-linux-gnu/libcudnn_graph.so.9.3.0
/usr/lib/x86_64-linux-gnu/libcudnn_heuristic.so.9.3.0
/usr/lib/x86_64-linux-gnu/libcudnn_ops.so.9.3.0
Is XPU available: False
HIP runtime version: N/A
MIOpen runtime version: N/A
Is XNNPACK available: True

CPU:
Architecture: x86_64
CPU op-mode(s): 32-bit, 64-bit
Address sizes: 48 bits physical, 48 bits virtual
Byte Order: Little Endian
CPU(s): 16
On-line CPU(s) list: 0-15
Vendor ID: AuthenticAMD
Model name: AMD EPYC 7R32
CPU family: 23
Model: 49
Thread(s) per core: 2
Core(s) per socket: 8
Socket(s): 1
Stepping: 0
BogoMIPS: 5599.89
Flags: fpu vme de pse tsc msr pae mce cx8 apic sep mtrr pge mca cmov pat pse36 clflush mmx fxsr sse sse2 ht syscall nx mmxext fxsr_opt pdpe1gb rdtscp lm constant_tsc rep_good nopl nonstop_tsc cpuid extd_apicid aperfmperf tsc_known_freq pni pclmulqdq ssse3 fma cx16 sse4_1 sse4_2 movbe popcnt aes xsave avx f16c rdrand hypervisor lahf_lm cmp_legacy cr8_legacy abm sse4a misalignsse 3dnowprefetch topoext ssbd ibrs ibpb stibp vmmcall fsgsbase bmi1 avx2 smep bmi2 rdseed adx smap clflushopt clwb sha_ni xsaveopt xsavec xgetbv1 clzero xsaveerptr rdpru wbnoinvd arat npt nrip_save rdpid
Hypervisor vendor: KVM
Virtualization type: full
L1d cache: 256 KiB (8 instances)
L1i cache: 256 KiB (8 instances)
L2 cache: 4 MiB (8 instances)
L3 cache: 32 MiB (2 instances)
NUMA node(s): 1
NUMA node0 CPU(s): 0-15
Vulnerability Gather data sampling: Not affected
Vulnerability Indirect target selection: Not affected
Vulnerability Itlb multihit: Not affected
Vulnerability L1tf: Not affected
Vulnerability Mds: Not affected
Vulnerability Meltdown: Not affected
Vulnerability Mmio stale data: Not affected
Vulnerability Reg file data sampling: Not affected
Vulnerability Retbleed: Mitigation; untrained return thunk; SMT vulnerable
Vulnerability Spec rstack overflow: Mitigation; safe RET
Vulnerability Spec store bypass: Mitigation; Speculative Store Bypass disabled via prctl and seccomp
Vulnerability Spectre v1: Mitigation; usercopy/swapgs barriers and __user pointer sanitization
Vulnerability Spectre v2: Mitigation; Retpolines, IBPB disabled, STIBP disabled, RSB filling, PBRSB-eIBRS Not affected
Vulnerability Srbds: Not affected
Vulnerability Tsx async abort: Not affected

Versions of relevant libraries:
[pip3] mypy_extensions==1.1.0
[pip3] natten==0.17.4+torch250cu121
[pip3] numpy==1.26.4
[pip3] nvidia-cublas-cu12==12.6.4.1
[pip3] nvidia-cuda-cupti-cu12==12.6.80
[pip3] nvidia-cuda-nvrtc-cu12==12.6.77
[pip3] nvidia-cuda-runtime-cu12==12.6.77
[pip3] nvidia-cudnn-cu12==9.10.2.21
[pip3] nvidia-cufft-cu12==11.3.0.4
[pip3] nvidia-curand-cu12==10.3.7.77
[pip3] nvidia-cusolver-cu12==11.7.1.2
[pip3] nvidia-cusparse-cu12==12.5.4.2
[pip3] nvidia-cusparselt-cu12==0.7.1
[pip3] nvidia-nccl-cu12==2.27.3
[pip3] nvidia-nvjitlink-cu12==12.6.85
[pip3] nvidia-nvtx-cu12==12.6.77
[pip3] onnx==1.18.0
[pip3] onnxconverter-common==1.15.0
[pip3] onnxruntime==1.22.1
[pip3] onnxruntime-tools==1.7.0
[pip3] tf2onnx==1.8.4
[pip3] torch==2.8.0+cu126
[pip3] torchaudio==2.8.0+cu126
[pip3] torchcodec==0.4.0+cu126
[pip3] torchvision==0.23.0+cu126
[pip3] triton==3.4.0
[conda] Could not collect

cc @chauhang @penguinwu
