Set dropout in SDPA to 0.0 when not in training mode by ebsmothers · Pull Request #1803 · meta-pytorch/torchtune · GitHub

Conversation

ebsmothers
Contributor

As pointed out by @zjost in #1791, we should manually force dropout to 0 outside of training. This is actually what we did originally, but along the way we switched to always using the value from attn_dropout directly, which is incorrect at inference time if someone passes a nonzero value. So this PR changes back to coercing dropout to 0.0 outside of training mode. I checked with the author of the PR that first dropped the if/else dropout logic, and the change was not made by design.

I also make two other changes: (1) reverting the doc update I made last night, and (2) raising an error in case someone tries to use FlexAttention with nonzero dropout (it's not currently supported).
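A minimal sketch of the two behavioral changes described above (the class and argument names here are illustrative, not torchtune's actual API): dropout is coerced to 0.0 whenever the module is not in training mode, and combining FlexAttention with nonzero dropout raises an error up front.

```python
import torch
import torch.nn as nn
import torch.nn.functional as F


class SDPAttention(nn.Module):
    """Illustrative attention wrapper; names are assumptions, not torchtune's API."""

    def __init__(self, attn_dropout: float = 0.0, use_flex_attention: bool = False):
        super().__init__()
        # FlexAttention does not currently support dropout, so fail fast.
        if use_flex_attention and attn_dropout > 0.0:
            raise ValueError("FlexAttention does not support nonzero attn_dropout")
        self.attn_dropout = attn_dropout

    def forward(self, q: torch.Tensor, k: torch.Tensor, v: torch.Tensor) -> torch.Tensor:
        # Coerce dropout to 0.0 outside of training mode so inference
        # is deterministic even if attn_dropout > 0 was configured.
        dropout_p = self.attn_dropout if self.training else 0.0
        return F.scaled_dot_product_attention(q, k, v, dropout_p=dropout_p)
```

With this guard, calling `model.eval()` is enough to make attention deterministic regardless of the configured `attn_dropout`.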

@ebsmothers ebsmothers requested a review from RdoubleA October 10, 2024 19:12
@facebook-github-bot facebook-github-bot added the CLA Signed This label is managed by the Facebook bot. Authors need to sign the CLA before a PR can be reviewed. label Oct 10, 2024
pytorch-bot bot commented Oct 10, 2024

🔗 Helpful Links

🧪 See artifacts and rendered test results at hud.pytorch.org/pr/pytorch/torchtune/1803

Note: Links to docs will display an error until the docs builds have been completed.

✅ No Failures

As of commit eb9a317 with merge base 5de5001:
💚 Looks good so far! There are no failures yet. 💚

This comment was automatically generated by Dr. CI and updates every 15 minutes.

@ebsmothers ebsmothers merged commit 665ab3f into meta-pytorch:main Oct 10, 2024
17 checks passed
@ebsmothers ebsmothers deleted the disable-sdpa-dropout-for-inference branch October 10, 2024 19:53
mori360 pushed a commit to mori360/torchtune that referenced this pull request Oct 14, 2024