guard on flash attention SymFloat scale instead of incorrectly casting to float by bdhirsh · Pull Request #141725 · pytorch/pytorch · GitHub

Conversation

@bdhirsh
Contributor

@bdhirsh bdhirsh commented Nov 27, 2024

Fixes #141710. Previously, if we called flash attention with a `SymFloat` scale that was properly symbolic, we would unsafely cast its raw `SymFloat._data` into a `float`, which is pretty much guaranteed to give `NaN`.

This avoids the NaNs in the linked issue, but I'm not sure if it's worth landing yet because we'll start specializing and recompiling for every distinct `scale` that is passed in (which in the dynamic shapes case, is some function of `query.size(-1)`).

The real fix would be to ensure that the flash attention (and related) ops all accept a symbolic version of the `scale`. I'm not sure if we should use `SymFloat` or `Scalar` though - more discussion in the issue.
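
To make the failure mode concrete, here is a minimal sketch of the unsafe pattern versus the guarded one. This is illustrative only, not the exact diff from this PR: the wrapper names are made up, and it assumes the `c10::SymFloat::guard_float(__FILE__, __LINE__)` accessor that mirrors `SymInt::guard_int`.

```cpp
#include <c10/core/SymFloat.h>

// Sketch only: sdp::calculate_scale() returns a c10::SymFloat, which under
// dynamic shapes may be backed by a symbolic expression rather than a
// concrete double.

double unsafe_scale(const c10::SymFloat& scale) {
  // BUG when scale.is_symbolic(): the raw payload is not the real value, so
  // this reads garbage and downstream attention math turns into NaNs.
  return scale.as_float_unchecked();
}

double guarded_scale(const c10::SymFloat& scale) {
  // Guarded path: resolve the symbolic value to a concrete double and record
  // a guard at this file/line, specializing (and recompiling) per distinct
  // scale value.
  return scale.guard_float(__FILE__, __LINE__);
}
```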

Stack from ghstack (oldest at bottom):

cc @jgong5 @mingfeima @XiaobingSuper @sanchitintel @ashokei @jingxu10 @voznesenskym @penguinwu @EikanWang @Guobing-Chen @zhuhaozhe @blzheng @wenzhe-nrv @jiayisunx @chenyang78 @kadeng @chauhang @amjames

[ghstack-poisoned]
@pytorch-bot

pytorch-bot bot commented Nov 27, 2024

🔗 Helpful Links

🧪 See artifacts and rendered test results at hud.pytorch.org/pr/141725

Note: Links to docs will display an error until the docs builds have been completed.

✅ No Failures

As of commit da90999 with merge base 9125e91:
💚 Looks good so far! There are no failures yet. 💚

This comment was automatically generated by Dr. CI and updates every 15 minutes.

Contributor

@ezyang ezyang left a comment


Use `__FILE__` and `__LINE__`

Specializing is better than silently wrong. Time to audit all other sites of this...

Need test.

@ezyang
Contributor

ezyang commented Nov 28, 2024

There are a lot of other sites that need to be fixed.

```
aten/src/ATen/native/cpu/FlashAttentionKernel.cpp:      sdp::calculate_scale(query, scale).as_float_unchecked();
aten/src/ATen/native/cpu/FlashAttentionKernel.cpp:      sdp::calculate_scale(query, scale).as_float_unchecked();
aten/src/ATen/native/mps/operations/Attention.mm:  auto scale_factor = sdp::calculate_scale(query, scale).as_float_unchecked();
aten/src/ATen/native/transformers/attention.cpp:            query_padded, key_padded, value_padded, dropout_p, is_causal, false /*return_debug_mask*/, og_scale.as_float_unchecked());
aten/src/ATen/native/transformers/cuda/attention.cu:  const auto softmax_scale = sdp::calculate_scale(query, scale).as_float_unchecked();
aten/src/ATen/native/transformers/cuda/attention.cu:      sdp::calculate_scale(query, scale).as_float_unchecked();
aten/src/ATen/native/transformers/cuda/attention.cu:  const auto softmax_scale = sdp::calculate_scale(query, scale).as_float_unchecked();
aten/src/ATen/native/transformers/cuda/attention.cu:    p.scale = sdp::calculate_scale(query, scale).as_float_unchecked();
aten/src/ATen/native/transformers/cuda/attention_backward.cu:  const auto softmax_scale = sdp::calculate_scale(query, scale).as_float_unchecked();
aten/src/ATen/native/transformers/cuda/attention_backward.cu:    const auto softmax_scale = sdp::calculate_scale(query, scale).as_float_unchecked();
aten/src/ATen/native/transformers/cuda/attention_backward.cu:  const auto softmax_scale = sdp::calculate_scale(query, scale).as_float_unchecked();
aten/src/ATen/native/transformers/cuda/attention_backward.cu:    p.scale = sdp::calculate_scale(query, scale).as_float_unchecked();
```

cc @drisspg

@ezyang ezyang requested a review from drisspg November 28, 2024 01:27
@bdhirsh
Contributor Author

bdhirsh commented Nov 28, 2024

I can audit the sites as part of this PR, thanks for the links. And landing specialization first sounds good.

@drisspg
Contributor

drisspg commented Nov 28, 2024

For my understanding: if the scale is a true float and not symbolic, this returns the correct result?

It is much less common in a user program (one that isn't doing benchmarking) to vary the head dim size, which is why this hasn't shown up earlier.
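
For background on the head-dim point: when no explicit scale is passed, the SDPA softmax scale defaults to 1/sqrt(head_dim), i.e. a function of `query.size(-1)`, so it only becomes symbolic when the head dim itself is dynamic. A small standalone sketch (hypothetical helper name, plain doubles rather than SymFloat):

```cpp
#include <cmath>
#include <cstdio>

// Default SDPA softmax scale when none is passed: 1/sqrt(head_dim), where
// head_dim == query.size(-1). With a concrete head dim this is an ordinary
// double; under dynamic shapes it becomes a symbolic SymFloat.
double default_sdpa_scale(long head_dim) {
  return 1.0 / std::sqrt(static_cast<double>(head_dim));
}

int main() {
  // Distinct head dims give distinct scales, so guarding on the scale means
  // one recompile per distinct head dim in the dynamic-shapes case.
  std::printf("head_dim=64  -> scale=%f\n", default_sdpa_scale(64));   // 0.125
  std::printf("head_dim=128 -> scale=%f\n", default_sdpa_scale(128));  // ~0.0884
  return 0;
}
```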

@bdhirsh
Contributor Author

bdhirsh commented Nov 28, 2024

For my understanding: if the scale is a true float and not symbolic, this returns the correct result?

Yup, that's right - that basically maps to this code: https://github.com/pytorch/pytorch/blob/main/c10/core/SymFloat.h#L41 (it is only safe to treat `_data` as a float if `!is_symbolic()`)
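
To spell out why the unchecked read is only safe in the non-symbolic case, here is an approximate paraphrase of the layout that line points at. It is a sketch, not a verbatim copy of `c10/core/SymFloat.h`; field names and details differ from the real class.

```cpp
// Approximate paraphrase only. The point: the raw double payload holds the
// value only in the non-symbolic case, so reading it unconditionally yields
// garbage (and NaNs downstream) for a symbolic scale.
class SymFloatSketch {
 public:
  bool is_symbolic() const {
    // A symbolic SymFloat is backed by a SymNode; a concrete one is not.
    return has_sym_node_;
  }
  double as_float_unchecked() const {
    // Only meaningful when !is_symbolic(); otherwise data_ is not the value.
    return data_;
  }
 private:
  bool has_sym_node_ = false;  // stand-in for the SymNode pointer
  double data_ = 0.0;          // the concrete value when !is_symbolic()
};
```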

@bdhirsh bdhirsh changed the title [not for land] guard on flash attention SymFloat scale instead of incorrectly casting to float guard on flash attention SymFloat scale instead of incorrectly casting to float Nov 28, 2024
@bdhirsh
Contributor Author

bdhirsh commented Nov 28, 2024

Looks like all of the other sites are inside of kernels (registered to backend dispatch keys), so `as_float_unchecked()` should be ok.

@ezyang
Contributor

ezyang commented Dec 2, 2024

Looks like all of the other sites are inside of kernels (registered to backend dispatch keys), so `as_float_unchecked()` should be ok.

Nah, we should use the guard in that case. I don't want `as_float_unchecked` except in core framework code.

@albanD albanD removed their request for review December 2, 2024 19:44
@bdhirsh
Contributor Author

bdhirsh commented Dec 2, 2024

sgtm, will update. (I'll try to land this PR tomorrow)

@pytorch-bot pytorch-bot bot added ciflow/inductor, ciflow/mps, module: cpu, module: dynamo, and release notes: mps labels Dec 3, 2024
@bdhirsh
Contributor Author

bdhirsh commented Dec 3, 2024

Removed the `as_float_unchecked` calls in the attention kernels, added tests.

@bdhirsh bdhirsh added the ciflow/trunk Trigger trunk jobs on your pull request label Dec 3, 2024
pytorchmergebot pushed a commit that referenced this pull request Dec 3, 2024
Fixes the silent correctness issue for SDPA in #141710.

Pull Request resolved: #141728
Approved by: https://github.com/Skylion007, https://github.com/ezyang, https://github.com/drisspg
ghstack dependencies: #141725
pytorchmergebot pushed a commit that referenced this pull request Dec 5, 2024
Fixes #142076. Under compile, functional collectives are supposed to **not** return `AsyncCollectiveTensor`, and instead immediately issue calls to `wait_tensor()` (which we rely on the compiler to reorder as necessary).

This is done with a function `_are_we_tracing()`, that tries to detect if we are running from inside of the compiler. One of the checks it performs is `is_torchdynamo_compiling()` ([here](https://github.com/pytorch/pytorch/blob/main/torch/distributed/_functional_collectives.py#L808C8-L808C34)).

Unfortunately, this will always return False, even if dynamo is indeed tracing. The problem is that this function only returns true if dynamo **intercepts** the bytecode for `is_torchdynamo_compiling()`. However, this function is called during fake-tensor propagation, which is run as part of dynamo, but is not actually intercepted by dynamo itself.

One thing that we know is the case during dynamo tracing, however, is that a `FakeTensorMode` is active. So I tweaked the logic to assume that we are tracing if there is an active fake mode.

This could potentially have consequences for anybody running functional collectives with a fake mode directly, without compile in the loop. Although hopefully it's not too unreasonable to issue wait() calls immediately if you are running with fake tensor (presumably you only care about fake tensor propagation, in which case the wait() calls should technically be a no-op).

Pull Request resolved: #142075
Approved by: https://github.com/yifuwang, https://github.com/kwen2501
ghstack dependencies: #141725, #141728
pobin6 pushed a commit to pobin6/pytorch that referenced this pull request Dec 5, 2024
pobin6 pushed a commit to pobin6/pytorch that referenced this pull request Dec 5, 2024
pobin6 pushed a commit to pobin6/pytorch that referenced this pull request Dec 5, 2024
AmdSampsa pushed a commit to AmdSampsa/pytorch that referenced this pull request Dec 9, 2024
AmdSampsa pushed a commit to AmdSampsa/pytorch that referenced this pull request Dec 9, 2024
AmdSampsa pushed a commit to AmdSampsa/pytorch that referenced this pull request Dec 9, 2024
@github-actions github-actions bot deleted the gh/bdhirsh/631/head branch January 3, 2025 02:08