guard on flash attention SymFloat scale instead of incorrectly casting to float #141725
Conversation
🔗 Helpful Links
🧪 See artifacts and rendered test results at hud.pytorch.org/pr/141725
Note: Links to docs will display an error until the docs builds have been completed.
✅ No Failures
As of commit da90999 with merge base 9125e91.
This comment was automatically generated by Dr. CI and updates every 15 minutes.
Use `__FILE__` and `__LINE__`
Specializing is better than silently wrong. Time to audit all other sites of this...
Need test.
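A minimal sketch of what the reviewer is asking for here, assuming the `guard_float(const char*, int64_t)` accessor on `c10::SymFloat` (the surrounding function is illustrative, not the actual kernel code):

```cpp
#include <c10/core/SymFloat.h>

// Illustrative call site: the scale arrives as a possibly-symbolic SymFloat.
double concrete_softmax_scale(const c10::SymFloat& scale) {
  // guard_float() specializes on the current value and records a guard
  // attributed to this file/line, so a later change in the value forces a
  // recompile instead of silently producing a wrong (NaN) result.
  return scale.guard_float(__FILE__, __LINE__);
}
```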
There are a lot of other sites that need to be fixed. cc @drisspg

I can audit the sites as part of this PR, thanks for the links. And landing specialization first sounds good
For my understanding, if the scale is a true float and not symbolic, this returns the correct result? It is much less common in a user program (that isn't doing benchmarking) to vary the head dim size, which is why this hasn't shown up earlier.
yup that's right - that basically maps to this code: https://github.com/pytorch/pytorch/blob/main/c10/core/SymFloat.h#L41 (it is only safe to treat the raw `_data` as a real `float` when the `SymFloat` is not symbolic)
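For reference, a small sketch of the invariant behind that link, assuming the `is_symbolic()` and `as_float_unchecked()` accessors on `c10::SymFloat` (the latter returning the raw payload):

```cpp
#include <c10/core/SymFloat.h>
#include <c10/util/Exception.h>

// Reading the raw payload of a SymFloat is only valid when the value is not
// symbolic; for a symbolic value the payload is a meaningless placeholder,
// which is where the NaNs in the linked issue came from.
double read_concrete_scale(const c10::SymFloat& scale) {
  TORCH_CHECK(!scale.is_symbolic(), "expected a concrete (non-symbolic) scale");
  return scale.as_float_unchecked();
}
```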
Looks like all of the other sites are inside of kernels (registered to backend dispatch keys), so they should only ever see a concrete scale - is it ok to leave the unchecked cast in those spots?
Nah, we should use the guard in that case. I don't want any unchecked casts left around that can silently produce wrong results.
sgtm, will update. (I'll try to land this PR tomorrow)
…ctly casting to float"

Fixes #141710. Previously, if we called flash attention with a `SymFloat` scale that was properly symbolic, we would unsafely cast its raw `SymFloat._data` into a `float`, which is pretty much guaranteed to give `NaN`.

This avoids the NaNs in the linked issue, but I'm not sure if it's worth landing yet because we'll start specializing and recompiling for every distinct `scale` that is passed in (which in the dynamic shapes case, is some function of `query.size(-1)`).

The real fix would be to ensure that the flash attention (and related) ops all accept a symbolic version of the `scale`. I'm not sure if we should use `SymFloat` or `Scalar` though - more discussion in the issue

[ghstack-poisoned]
Removed the …
Fixes the silent correctness issue for SDPA in #141710

Pull Request resolved: #141728
Approved by: https://github.com/Skylion007, https://github.com/ezyang, https://github.com/drisspg
ghstack dependencies: #141725
Fixes #142076. Under compile, functional collectives are supposed to **not** return `AsyncCollectiveTensor`, and instead immediately issue calls to `wait_tensor()` (which we rely on the compiler to reorder as necessary). This is done with a function `_are_we_tracing()`, which tries to detect if we are running from inside of the compiler. One of the checks it performs is `is_torchdynamo_compiling()` ([here](https://github.com/pytorch/pytorch/blob/main/torch/distributed/_functional_collectives.py#L808C8-L808C34)). Unfortunately, this will always return False, even if dynamo is indeed tracing. The problem is that this function only returns true if dynamo **intercepts** the bytecode for `is_torchdynamo_compiling()`. However, this function is called during fake-tensor propagation, which is run as part of dynamo, but is not actually intercepted by dynamo itself.

One thing that we know is the case during dynamo tracing, however, is that a `FakeTensorMode` is active. So I tweaked the logic to assume that we are tracing if there is an active fake mode. This could potentially have consequences for anybody running functional collectives with a fake mode directly, without compile in the loop. Although hopefully it's not too unreasonable to issue wait() calls immediately if you are running with fake tensor (presumably you only care about fake tensor propagation, in which case the wait() calls should technically be a no-op).

Pull Request resolved: #142075
Approved by: https://github.com/yifuwang, https://github.com/kwen2501
ghstack dependencies: #141725, #141728
Fixes #141710. Previously, if we called flash attention with a `SymFloat` scale that was properly symbolic, we would unsafely cast its raw `SymFloat._data` into a `float`, which is pretty much guaranteed to give `NaN`.

This avoids the NaNs in the linked issue, but I'm not sure if it's worth landing yet because we'll start specializing and recompiling for every distinct `scale` that is passed in (which in the dynamic shapes case, is some function of `query.size(-1)`).

The real fix would be to ensure that the flash attention (and related) ops all accept a symbolic version of the `scale`. I'm not sure if we should use `SymFloat` or `Scalar` though - more discussion in the issue.

Stack from ghstack (oldest at bottom):
cc @jgong5 @mingfeima @XiaobingSuper @sanchitintel @ashokei @jingxu10 @voznesenskym @penguinwu @EikanWang @Guobing-Chen @zhuhaozhe @blzheng @wenzhe-nrv @jiayisunx @chenyang78 @kadeng @chauhang @amjames
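To illustrate why the scale ends up symbolic under dynamic shapes, here is a rough sketch of the default-scale computation, loosely modeled on the SDPA scale helper (the helper name, the `SymInt`-to-`SymFloat` conversion, and the `sqrt()` accessor used here are assumptions, not the exact upstream code):

```cpp
#include <ATen/core/Tensor.h>
#include <c10/core/SymFloat.h>
#include <optional>

// When no explicit scale is passed, SDPA uses 1/sqrt(head_dim). Under dynamic
// shapes, query.sym_size(-1) is a symbolic SymInt, so the resulting SymFloat
// is symbolic too and has to be guarded (not raw-cast) before it can be used
// as a plain double by the flash attention kernel.
c10::SymFloat default_sdpa_scale(
    const at::Tensor& query,
    std::optional<double> user_scale) {
  if (user_scale.has_value()) {
    return c10::SymFloat(*user_scale); // concrete value, nothing symbolic
  }
  return c10::SymFloat(1.0) / c10::SymFloat(query.sym_size(-1)).sqrt();
}
```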