Closed
Labels
module: dynamic shapes, module: inductor, oncall: pt2, triaged
Description
🐛 Describe the bug
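The following function fails to compile:

```python
import torch

@torch.compile(dynamic=True)
def foo(a):
    # With dynamic=True, the Python int `a` is traced as a symbolic scalar (s0)
    return a - torch.zeros(3)

foo(4)
```

Related to #108067, but it still fails even after that issue is fixed.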
```
  File "torch/_inductor/graph.py", line 610, in call_function
    out = lowerings[target](*args, **kwargs)
  File "torch/_inductor/lowering.py", line 279, in wrapped
    out = decomp_fn(*args, **kwargs)
  File "torch/_inductor/lowering.py", line 388, in inner
    assert isinstance(other, ir.BaseConstant) or len(ranges) == len(
torch._dynamo.exc.BackendCompilerFailed: backend='inductor' raised:
LoweringException: AssertionError: ndim mismatch <function ops_wrapper.<locals>.fn at 0x7f491d8fcb80> () [3]
  target: aten.sub.Tensor
  args[0]: s0
  args[1]: TensorBox(StorageBox(
    Pointwise(
      'cpu',
      torch.float32,
      def inner_fn(index):
          i0 = index
          tmp0 = ops.constant(0, torch.float32)
          return tmp0
      ,
      ranges=[3],
      origin_node=full_default,
      origins={full_default}
    )
  ))
```

Versions
Reproduced with #180160 applied (it also fails on master, but with a different error).
cc @ezyang @msaroufim @wconstab @bdhirsh @anijain2305 @voznesenskym @penguinwu @EikanWang @jgong5 @Guobing-Chen @XiaobingSuper @zhuhaozhe @blzheng @Xia-Weiwen @wenzhe-nrv @jiayisunx @peterbell10 @ipiszy @ngimel @yf225 @chenyang78 @kadeng @muchulee8 @aakhundov
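For context, the `ndim mismatch ... () [3]` assertion suggests that the lowering of `aten.sub.Tensor` receives the symbolic scalar `s0` as a 0-d operand (ranges `()`) next to the `[3]`-shaped zeros tensor, and rejects the pair instead of broadcasting it. In eager mode, `foo(4)` returns `tensor([4., 4., 4.])`. The following is a hypothetical, plain-Python sketch of the standard broadcasting rule that such a lowering would need to honor. `broadcast_ranges` is an assumed helper name for illustration only and is not part of Inductor.

```python
from itertools import zip_longest


def broadcast_ranges(a, b):
    """Broadcast two shapes NumPy/PyTorch style (hypothetical helper, not Inductor code)."""
    result = []
    # Walk both shapes from the trailing dimension; a missing leading
    # dimension (e.g. every dimension of a 0-d scalar) counts as size 1.
    for x, y in zip_longest(reversed(a), reversed(b), fillvalue=1):
        if x == y or y == 1:
            result.append(x)
        elif x == 1:
            result.append(y)
        else:
            raise ValueError(f"shapes {tuple(a)} and {tuple(b)} are not broadcastable")
    return tuple(reversed(result))


# The failing case from the traceback: a 0-d scalar against ranges [3].
print(broadcast_ranges((), [3]))  # (3,)
```

Under this rule a 0-d operand is always compatible with any other shape, which is why eager mode accepts `a - torch.zeros(3)` without complaint.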