[inductor] [silent incorrectness] Multiple internal `torch.rand` can lead to inconsistent results with eager · Issue #151524 · pytorch/pytorch · GitHub


@shaoyuyoung


🐛 Describe the bug

Symptom: if `torch.rand` is called only once in the forward function, the compiled output matches eager. However, when `torch.rand` is called two or more times, the compiled output differs from eager. It looks like multiple internal `torch.rand` calls don't respect `fallback_random` (?)
Device backend: both CPP and Triton

```python
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch._inductor import config

# Ask Inductor to fall back to eager's random ops, so compiled results
# are expected to match eager under the same seed.
config.fallback_random = True
torch.set_grad_enabled(False)
torch.manual_seed(0)


class Model(nn.Module):

    def __init__(self):
        super().__init__()

    def forward(self):
        x = torch.rand(1)  # overwritten below, but eager still draws it
        x = torch.rand(1)
        return x


model = Model()

inputs = []


def run_test(model, inputs, backend):
    if backend != "eager":
        model = torch.compile(model, backend=backend)
    # Reseed so both backends start from the same RNG state.
    torch.manual_seed(0)
    output = model()
    return output


output = run_test(model, inputs, 'eager')
c_output = run_test(model, inputs, 'inductor')
print(output)
print(c_output)
```

Error logs

First line: eager output; second line: Inductor output.

```
tensor([0.7682])
tensor([0.4963])
```
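The logged values fit a simple pattern: after `torch.manual_seed(0)`, the first `torch.rand` draw on CPU is 0.4963 and the second is 0.7682. Eager returns the second draw, as expected, while the compiled model returns the first. That is what one would see if the compiled graph performed only one draw, for example if the overwritten first `torch.rand(1)` were treated as dead code; this is a hypothesis, not a confirmed root cause. The toy model below uses Python's `random.Random` as a stand-in generator (an assumption: it is not torch or Inductor code, and `forward_eager` / `forward_dead_draw_removed` are hypothetical names) to show why dropping an unused random call is not semantics-preserving:

```python
import random

# Toy model only: random.Random stands in for torch's default generator.
# This is NOT Inductor code; it illustrates why removing an "unused"
# random call changes what later calls return.


def forward_eager(gen: random.Random) -> float:
    x = gen.random()  # first draw: the value is discarded...
    x = gen.random()  # ...but the first call already advanced the state
    return x


def forward_dead_draw_removed(gen: random.Random) -> float:
    # Hypothetical transformed graph in which the first, unused draw
    # was eliminated, so only one draw remains.
    x = gen.random()
    return x


seed = 0
eager = forward_eager(random.Random(seed))
transformed = forward_dead_draw_removed(random.Random(seed))
first_draw = random.Random(seed).random()

print(eager == transformed)       # False: results diverge
print(transformed == first_draw)  # True: transformed returns the first draw
```

The point of the sketch is that a random op has a side effect on generator state even when its result is unused, so eager semantics depend on every call being executed in order.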

Versions

nightly 20250414

cc @ezyang @gchanan @zou3519 @kadeng @msaroufim @chauhang @penguinwu @voznesenskym @EikanWang @jgong5 @Guobing-Chen @XiaobingSuper @zhuhaozhe @blzheng @wenzhe-nrv @jiayisunx @ipiszy @chenyang78 @muchulee8 @amjames @aakhundov

Metadata

Labels: high priority, module: inductor, oncall: pt2, triaged
