fix incorrect c10::SymFloat::sqrt by bdhirsh · Pull Request #141728 · pytorch/pytorch · GitHub

Conversation

bdhirsh (Contributor) commented Nov 27, 2024

[ghstack-poisoned]
pytorch-bot bot commented Nov 27, 2024

🔗 Helpful Links

🧪 See artifacts and rendered test results at hud.pytorch.org/pr/141728

Note: Links to docs will display an error until the docs builds have been completed.

✅ No Failures

As of commit f3ccc5d with merge base 9125e91:
💚 Looks good so far! There are no failures yet. 💚

This comment was automatically generated by Dr. CI and updates every 15 minutes.

bdhirsh added a commit that referenced this pull request Nov 27, 2024
ghstack-source-id: b1dca33
Pull Request resolved: #141728
Skylion007 (Collaborator) commented:

Yikes! This is a good backport candidate.

drisspg (Contributor) commented Nov 28, 2024

Summary

  1. I think it is unlikely that this value ends up being symbolic (definitely possible, but unlikely).
  2. For the math path with SymFloats, we also got lucky in the SymFloat case; see the reasoning below.
  3. For flash-attention this is only a problem if the scale is actually a SymFloat, but since we were using as_float_unchecked() we already had bigger problems than this scale miscalculation (illustrated in the sketch after this list). I don't think the other calls within kernel dispatch could ever run the SymFloat path.
  4. For the non-SymFloat case we use std::sqrt, which is correct; that is why this didn't show up in any unit tests, since I imagine we never exercised a dynamic last dim in any tests.
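To make the scale miscalculation concrete, here is a minimal sketch (plain Python with illustrative names, not the c10 or SDPA code; it assumes the conventional softmax scale of 1 / sqrt(head_dim)):

```python
import math

# Illustrative only: what a sqrt() that silently computes x ** -0.5 would do
# to an SDPA softmax scale of 1 / sqrt(head_dim). The non-symbolic path goes
# through std::sqrt and stays correct, per point 4 above.
head_dim = 64.0

def buggy_sqrt(x: float) -> float:
    # the incorrect behavior: a reciprocal square root
    return math.pow(x, -0.5)

correct_scale = 1.0 / math.sqrt(head_dim)   # 0.125
buggy_scale = 1.0 / buggy_sqrt(head_dim)    # 8.0 -- logits blown up 64x
print(correct_scale, buggy_scale)
```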

Math path reasoning:
To reduce numerical deviation we use a trick: instead of applying the scale once to the product q @ k.T, we apply its sqrt twice, once to the query and once to the key.

We only ever used sqrt as
bad_sqrt = pow(scale, -0.5)
good_sqrt = pow(scale, 0.5)

So for good sqrt you have

good_sqrt(x) = x^(0.5)
good_sqrt(good_sqrt(scale)) = (scale^0.5)^0.5 = scale^0.25

and for bad sqrt you have

bad_sqrt(x) = x^(-0.5)
bad_sqrt(bad_sqrt(scale)) = (scale^-0.5)^-0.5 = scale^0.25
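A quick numerical check of that identity (a standalone sketch; good_sqrt and bad_sqrt are illustrative names, not PyTorch APIs):

```python
import math

# Applying the buggy sqrt (x ** -0.5) twice lands on the same value as
# applying the correct sqrt (x ** 0.5) twice, so the double-sqrt math path
# happened to be unaffected by the bug.
def good_sqrt(x: float) -> float:
    return math.pow(x, 0.5)

def bad_sqrt(x: float) -> float:
    return math.pow(x, -0.5)

scale = 64.0
assert math.isclose(good_sqrt(good_sqrt(scale)), scale ** 0.25)
assert math.isclose(bad_sqrt(bad_sqrt(scale)), scale ** 0.25)
```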

Overall, a very important fix and a great catch. I had a heart attack reading this at first and needed to do some sanity checks as to why this never showed up before; I think the existing blast radius is fairly well contained.

@Skylion007 Skylion007 added the topic: bug fixes topic category label Nov 29, 2024
Fixes the silent correctness issue for SDPA in #141710

[ghstack-poisoned]
@bdhirsh bdhirsh added ciflow/trunk Trigger trunk jobs on your pull request release notes: composability release notes category labels Dec 3, 2024
bdhirsh added a commit that referenced this pull request Dec 3, 2024
ghstack-source-id: e9187a2
Pull Request resolved: #141728
@albanD albanD removed their request for review December 3, 2024 21:31
bdhirsh (Contributor, Author) commented Dec 3, 2024

@pytorchbot merge

pytorchmergebot (Collaborator) commented:

Merge started

Your change will be merged once all checks pass (ETA 0-4 Hours).

Learn more about merging in the wiki.

Questions? Feedback? Please reach out to the PyTorch DevX Team

Advanced Debugging
Check the merge workflow status here

pytorchmergebot pushed a commit that referenced this pull request Dec 5, 2024
Fixes #142076. Under compile, functional collectives are supposed to **not** return `AsyncCollectiveTensor`, and instead immediately issue calls to `wait_tensor()` (which we rely on the compiler to reorder as necessary).

This is done with a function `_are_we_tracing()`, which tries to detect whether we are running inside the compiler. One of the checks it performs is `is_torchdynamo_compiling()` ([here](https://github.com/pytorch/pytorch/blob/main/torch/distributed/_functional_collectives.py#L808C8-L808C34)).

Unfortunately, this will always return False, even if dynamo is indeed tracing. The problem is that this function only returns true if dynamo **intercepts** the bytecode for `is_torchdynamo_compiling()`. However, this function is called during fake-tensor propagation, which is run as part of dynamo, but is not actually intercepted by dynamo itself.

One thing that we know is the case during dynamo tracing, however, is that a `FakeTensorMode` is active. So I tweaked the logic to assume that we are tracing if there is an active fake mode.

This could potentially have consequences for anybody running functional collectives with a fake mode directly, without compile in the loop. Although hopefully it's not too unreasonable to issue wait() calls immediately if you are running with fake tensor (presumably you only care about fake tensor propagation, in which case the wait() calls should technically be a no-op).
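
A minimal sketch of the tweaked check (an approximation of the described logic, not the code in _functional_collectives.py; it assumes torch.compiler.is_dynamo_compiling and torch._guards.detect_fake_mode are available):

```python
import torch
from torch._guards import detect_fake_mode

# Sketch of the heuristic described above: is_torchdynamo_compiling() is not
# intercepted during dynamo's own fake-tensor propagation, so additionally
# treat "a FakeTensorMode is currently active" as evidence that we are tracing.
def are_we_tracing_sketch() -> bool:
    if torch.compiler.is_dynamo_compiling():
        return True
    # detect_fake_mode() returns the active FakeTensorMode, or None.
    return detect_fake_mode() is not None
```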

Pull Request resolved: #142075
Approved by: https://github.com/yifuwang, https://github.com/kwen2501
ghstack dependencies: #141725, #141728
pobin6 pushed a commit to pobin6/pytorch that referenced this pull request Dec 5, 2024
AmdSampsa pushed a commit to AmdSampsa/pytorch that referenced this pull request Dec 9, 2024
@github-actions github-actions bot deleted the gh/bdhirsh/632/head branch January 3, 2025 02:08

Labels: ciflow/inductor, ciflow/trunk, Merged, module: dynamo, release notes: composability, topic: bug fixes
