Fix DCE eliminating random operations by improving is_impure() (#151524) by soumith · Pull Request #157981 · pytorch/pytorch · GitHub

@soumith soumith commented Jul 10, 2025

DCE was incorrectly eliminating unused random operations like torch.rand() that have global RNG side effects, causing inconsistent results between eager and compiled execution modes.

Root cause: Python random functions (torch.rand, torch.randn, etc.) don't have the _nondeterministic_seeded attribute, so node.is_impure() returns False, allowing DCE to eliminate them despite advancing global RNG state.

Solution: Enhanced is_impure() in torch/fx/node.py to recognize Python random functions and mark them as impure when they use global RNG, regardless of the impure_random parameter setting. This ensures consistency between eager and compiled execution even when config.fallback_random=False.
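To make the failure mode concrete, here is a minimal sketch that does not use PyTorch at all. It assumes Python's standard `random` module as a stand-in for the global torch RNG, and the function names `eager` and `after_dce` are hypothetical. It only illustrates why removing an unused draw is observable; it is not the code path touched by this PR.

```python
import random


def eager(seed: int) -> float:
    """Run both draws, the way eager execution would."""
    random.seed(seed)
    _unused = random.random()  # result is never read, but the global RNG advances
    return random.random()


def after_dce(seed: int) -> float:
    """Same program after a DCE pass wrongly removed the unused draw."""
    random.seed(seed)
    return random.random()


if __name__ == "__main__":
    # The two values differ because the second function skipped one RNG step.
    print(eager(0), after_dce(0))
```

With the same seed, `after_dce` returns what `eager` would have produced for its first (discarded) draw, which is the kind of eager/compiled mismatch the linked issue reports.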

Key features:

  • Handles comprehensive list of random functions: rand, randn, randint, randperm, rand_like, randn_like, randint_like, normal, poisson, bernoulli, multinomial
  • Generator optimization: Only marks as impure when using global RNG (no generator or generator=None). Operations with explicit generators don't affect global state and can be optimized.
  • Works with both impure_random=True and impure_random=False cases
  • Cleaner architecture: addresses root cause rather than working around it
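The interaction between a purity predicate and DCE can be sketched with a toy graph IR. Everything below is an assumption made for illustration: `Node`, `RANDOM_OPS`, `naive_is_impure`, `rng_aware_is_impure`, and `eliminate_dead_code` are hypothetical names and do not reproduce the actual torch.fx implementation.

```python
from dataclasses import dataclass
from typing import Callable, List, Tuple

# Hypothetical op names for this toy IR; they only mirror part of the list above.
RANDOM_OPS = {"rand", "randn", "randint", "randperm", "bernoulli"}


@dataclass(frozen=True)
class Node:
    name: str
    op: str
    inputs: Tuple[str, ...] = ()


def naive_is_impure(node: Node) -> bool:
    """Treats every op as pure, so unused random ops look dead."""
    return False


def rng_aware_is_impure(node: Node) -> bool:
    """Marks random ops as impure because they advance RNG state."""
    return node.op in RANDOM_OPS


def eliminate_dead_code(
    graph: List[Node], outputs: List[str], is_impure: Callable[[Node], bool]
) -> List[Node]:
    """Keep nodes that feed an output or are impure; drop the rest."""
    live = set(outputs)
    kept: List[Node] = []
    # Walk backwards so each user is visited before its producers.
    for node in reversed(graph):
        if node.name in live or is_impure(node):
            kept.append(node)
            live.update(node.inputs)
    kept.reverse()
    return kept


graph = [
    Node("x", "placeholder"),
    Node("unused", "rand"),  # result is never consumed
    Node("noise", "rand"),
    Node("out", "add", ("x", "noise")),
]

naive = [n.name for n in eliminate_dead_code(graph, ["out"], naive_is_impure)]
aware = [n.name for n in eliminate_dead_code(graph, ["out"], rng_aware_is_impure)]
```

In this sketch the naive predicate drops the `unused` node, while the RNG-aware predicate keeps all four nodes, so the order of RNG draws matches the original program.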

Tests: Enhanced test_impure_random to verify both FX tracing and AOT compilation codepaths, ensuring random operations are preserved and eager/compiled execution consistency is maintained.

🤖 Generated with Claude Code

Fixes #151524

cc @ezyang @SherlockNoMad @EikanWang @jgong5 @wenzhe-nrv

pytorch-bot bot commented Jul 10, 2025

🔗 Helpful Links

🧪 See artifacts and rendered test results at hud.pytorch.org/pr/157981

Note: Links to docs will display an error until the docs builds have been completed.

✅ No Failures

As of commit 11ced11 with merge base 86251ef (image):
💚 Looks good so far! There are no failures yet. 💚

This comment was automatically generated by Dr. CI and updates every 15 minutes.

@pytorch-bot pytorch-bot bot added the release notes: fx release notes category label Jul 10, 2025
@soumith soumith requested review from anijain2305 and mlazos July 10, 2025 00:23
torch/fx/node.py Outdated
}

if self.target in _random_functions:
# Only impure if using global RNG (no generator or generator=None)
Contributor

I mean technically they're still impure if generator is used. If I have some calls to rand that use a specific generator and I DCE them, the later generator calls will get modified.

Member Author

I thought impure in this context only refers to global side-effects, not local mutations. Am I wrong?

Member Author

Edward is right here. It's not about global versus local purity, but whether it's a pure function or not.
Especially in the context of fixing this issue (i.e. the discrepancy between eager and compiled), we should be treating this as impure.
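The reviewer's point about explicit generators can be checked with a small sketch. It assumes a `random.Random` instance from the Python standard library as a stand-in for a `torch.Generator`, and `later_draw` is a hypothetical helper name. It only shows that a generator-local draw is still an observable side effect.

```python
import random


def later_draw(first_draw_removed: bool, seed: int = 0) -> float:
    """Return the draw that follows an unused draw on a private generator.

    first_draw_removed=True models a DCE pass that eliminated the unused draw.
    """
    gen = random.Random(seed)  # explicit generator, separate from the global RNG
    if not first_draw_removed:
        gen.random()  # result is discarded, but gen's internal state advances
    return gen.random()


if __name__ == "__main__":
    # Different values: eliminating the first draw shifts every later draw.
    print(later_draw(False), later_draw(True))
```

The global RNG is untouched in both cases, yet the value returned to the caller changes, which is why treating generator-backed random ops as pure would still break eager/compiled consistency.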

…51524)

Random operations with explicit generators can still affect observable behavior
when eliminated, causing different generator states between eager and compiled
execution modes. All random operations should be preserved by DCE.

🤖 Generated with [Claude Code](https://claude.ai/code)

Co-Authored-By: Claude <noreply@anthropic.com>

soumith commented Jul 10, 2025

@pytorchbot merge

@pytorch-bot pytorch-bot bot added the ciflow/trunk Trigger trunk jobs on your pull request label Jul 10, 2025
@pytorchmergebot
Collaborator

Merge started

Your change will be merged once all checks pass (ETA 0-4 Hours).

Learn more about merging in the wiki.

Questions? Feedback? Please reach out to the PyTorch DevX Team

Advanced Debugging
Check the merge workflow status
here

Labels

ciflow/trunk Trigger trunk jobs on your pull request fx Merged release notes: fx release notes category

Development

Successfully merging this pull request may close these issues.

[inductor] [silent incorrectness] Multiple internal torch.rand can lead to inconsistent results with eager

5 participants