AsyncCollectiveTensor: fix _are_we_tracing() in dynamo #142075
Conversation
I was wondering why the eager-mode behavior of DTensor cannot be like this.
LGTM!
It seems that `torch.compiler.is_compiling()` should be a more comprehensive check that could be used here. However, it doesn't seem like anything aside from export actually sets the underlying flag?
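For reference, a minimal sketch of how that check behaves (using the public `torch.compiler.is_compiling()` API; dynamo constant-folds the branch at trace time):

```python
import torch

def f(x):
    # torch.compiler.is_compiling() is documented to be True when a graph is
    # traced via torch.compile() or torch.export(), and False in plain eager.
    if torch.compiler.is_compiling():
        return x + 1  # traced path
    return x - 1      # eager path

x = torch.zeros(2)
print(f(x))                 # tensor([-1., -1.]) -- eager
print(torch.compile(f)(x))  # tensor([1., 1.]) -- dynamo sees is_compiling() == True
```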
Yep, it looks like... I think the main issue is that nothing outside of export actually flips that flag today.
hmm, I might be misunderstanding your question, but eager DTensor doesn't do this because it would result in eager DTensor usage not overlapping compute and comms. Under compile, we force functional collectives to be synchronous because we can rely on the compiler to reorder the `wait_tensor()` calls later as necessary for good perf.
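To make the eager-mode behavior concrete, a minimal single-rank sketch (the gloo backend, address/port, and world size 1 are just assumptions so it runs standalone):

```python
import os
import torch
import torch.distributed as dist
import torch.distributed._functional_collectives as funcol

# Single-process "world", purely for illustration.
os.environ.setdefault("MASTER_ADDR", "127.0.0.1")
os.environ.setdefault("MASTER_PORT", "29500")
dist.init_process_group("gloo", rank=0, world_size=1)

x = torch.ones(4)
# In eager mode the functional collective returns an AsyncCollectiveTensor,
# deferring the sync point so communication can overlap with later compute.
y = funcol.all_reduce(x, "sum", group=dist.group.WORLD)
print(type(y).__name__)  # AsyncCollectiveTensor
y = y.wait()             # explicit sync; under compile this happens eagerly

dist.destroy_process_group()
```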
I was thinking, if people need good perf, they should turn on torch.compile. Nevertheless, thanks for the fix!
@bdhirsh oh I was referring to this (instead of `is_torchdynamo_compiling()`). The documentation indicates that it should cover non-bytecode cases as well. But I don't think anything aside from export is setting the underlying flag.
@yifuwang yeah that is fair. There are probably two ways we could fix this: (1) check if a fake tensor mode is active. I'm not actually sure how ok this is to do: FakeTensorMode is a user API, so someone could find it unintuitive for wait() calls to be issued immediately under it. (2) do something similar to what is_compiling() does for export in that snippet: flip a global bool on entry/exit to dynamo. We could probably do this, although it does mean (a bit of) extra work in the hot path of every compiled graph invocation. I'll probably leave that discussion as a followup, separate from this PR.
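A hedged sketch of option (1): treat an active fake mode as evidence of tracing. The dispatch-mode query below mirrors what `torch._guards.detect_fake_mode` does internally, and is illustrative rather than the exact code this PR lands:

```python
import torch
from torch._subclasses.fake_tensor import FakeTensorMode

def fake_mode_active() -> bool:
    # Ask the torch-dispatch mode stack whether a FakeTensorMode is active.
    mode = torch._C._get_dispatch_mode(torch._C._TorchDispatchModeKey.FAKE)
    return mode is not None

print(fake_mode_active())      # False in plain eager
with FakeTensorMode():
    print(fake_mode_active())  # True while a fake mode is active
```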
@pytorchbot merge |
Merge started. Your change will be merged once all checks pass (ETA 0-4 hours). Learn more about merging in the wiki. Questions? Feedback? Please reach out to the PyTorch DevX Team.
Yeah this change is good. I was just curious if there's an API that could work.
I think this sgtm too. The only thing I want to check: for the AutoFSDP work, we model compute/comm time based on a FakeTensor trace. Will this be impacted by the change? I'm not sure. cc @weifengpy
Should be good. For 1D FSDP, we trace vanilla modules (no collectives) and decide the module wrapping policy.
Fixes #142076. Under compile, functional collectives are supposed to **not** return `AsyncCollectiveTensor`, and instead immediately issue calls to `wait_tensor()` (which we rely on the compiler to reorder as necessary). This is done with a function `_are_we_tracing()` that tries to detect if we are running from inside of the compiler. One of the checks it performs is `is_torchdynamo_compiling()` ([here](https://github.com/pytorch/pytorch/blob/main/torch/distributed/_functional_collectives.py#L808C8-L808C34)).

Unfortunately, this will always return False, even if dynamo is indeed tracing. The problem is that this function only returns true if dynamo **intercepts** the bytecode for `is_torchdynamo_compiling()`. However, this function is called during fake-tensor propagation, which is run as part of dynamo but is not actually intercepted by dynamo itself.

One thing that we do know is the case during dynamo tracing, however, is that a `FakeTensorMode` is active. So I tweaked the logic to assume that we are tracing if there is an active fake mode.

This could potentially have consequences for anybody running functional collectives with a fake mode directly, without compile in the loop. Although hopefully it's not too unreasonable to issue wait() calls immediately if you are running with fake tensor (presumably you only care about fake tensor propagation, in which case the wait() calls should technically be a no-op).

Pull Request resolved: #142075
Approved by: https://github.com/yifuwang, https://github.com/kwen2501
ghstack dependencies: #141725, #141728
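Illustratively, the shape of the tweaked check (a sketch of the idea, not the verbatim upstream code; the real `_are_we_tracing()` performs additional checks, e.g. for functionalization and proxy modes):

```python
import torch
from torch.compiler import is_dynamo_compiling

def _are_we_tracing_sketch() -> bool:
    # Only True when dynamo actually intercepts this bytecode -- which it
    # does not during its own fake-tensor propagation...
    if is_dynamo_compiling():
        return True
    # ...so additionally treat an active FakeTensorMode as "we are tracing":
    # dynamo always runs fake propagation under an active fake mode.
    fake = torch._C._get_dispatch_mode(torch._C._TorchDispatchModeKey.FAKE)
    return fake is not None
```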

cc @H-Huang @awgu @kwen2501 @wanchaol @fegin @fduwjj @wz337 @wconstab @d4l3k @c-p-i-o