AsyncCollectiveTensor: fix _are_we_tracing() in dynamo by bdhirsh · Pull Request #142075 · pytorch/pytorch · GitHub

Conversation

@bdhirsh
Contributor

@bdhirsh bdhirsh commented Dec 4, 2024

Fixes #142076. Under compile, functional collectives are supposed to not return AsyncCollectiveTensor, and instead immediately issue calls to wait_tensor() (which we rely on the compiler to reorder as necessary).

This is done with a function _are_we_tracing(), which tries to detect whether we are running from inside the compiler. One of the checks it performs is is_torchdynamo_compiling() (here).

Unfortunately, this will always return False, even if dynamo is indeed tracing. The problem is that is_torchdynamo_compiling() only returns True if dynamo intercepts its bytecode. However, _are_we_tracing() is called during fake-tensor propagation, which runs as part of dynamo but is not itself intercepted by dynamo.

One thing that we know is the case during dynamo tracing, however, is that a FakeTensorMode is active. So I tweaked the logic to assume that we are tracing if there is an active fake mode.

This could potentially have consequences for anybody running functional collectives under a fake mode directly, without compile in the loop. Hopefully it's not too unreasonable to issue wait() calls immediately when running with fake tensors: presumably you only care about fake-tensor propagation, in which case the wait() calls should technically be a no-op.
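
For readers following along, here is a minimal sketch of the shape of the check described above; the import paths and the exact combination of checks are illustrative assumptions, not the literal code in _functional_collectives.py:

```python
import torch
from torch._guards import detect_fake_mode


def _are_we_tracing_sketch() -> bool:
    # Bytecode-level check: only True when dynamo intercepts the bytecode of
    # the calling frame, which does not happen for framework code executed
    # during dynamo's own fake-tensor propagation.
    if torch.compiler.is_dynamo_compiling():
        return True
    # The check this PR adds: an active FakeTensorMode is a reliable signal
    # that a tracer is running, so issue wait_tensor() calls eagerly.
    return detect_fake_mode() is not None
```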

Stack from ghstack (oldest at bottom):

cc @H-Huang @awgu @kwen2501 @wanchaol @fegin @fduwjj @wz337 @wconstab @d4l3k @c-p-i-o

@pytorch-bot

pytorch-bot bot commented Dec 4, 2024

🔗 Helpful Links

🧪 See artifacts and rendered test results at hud.pytorch.org/pr/142075

Note: Links to docs will display an error until the docs builds have been completed.

✅ You can merge normally! (2 Unrelated Failures)

As of commit a358dc7 with merge base 9125e91:

UNSTABLE - The following jobs failed, but the failures were likely due to flakiness present on trunk and have been marked as unstable:

This comment was automatically generated by Dr. CI and updates every 15 minutes.

@pytorch-bot pytorch-bot bot added ciflow/inductor oncall: distributed Add this issue/PR to distributed oncall triage queue labels Dec 4, 2024
bdhirsh added a commit that referenced this pull request Dec 4, 2024
@kwen2501
Contributor

kwen2501 commented Dec 4, 2024

Under compile, functional collectives are supposed to not return AsyncCollectiveTensor, and instead immediately issue calls to wait_tensor() (which we rely on the compiler to reorder as necessary).

I was wondering why the eager-mode behavior of DTensor cannot be like this.

@bdhirsh bdhirsh added the release notes: distributed (dtensor) release notes category label Dec 4, 2024
Collaborator

@yifuwang yifuwang left a comment


LGTM!

It seems that torch.compiler.is_compiling() should be a more comprehensive check that can be used here. However, it doesn't seem like anything aside from export is setting the underlying flag?

@albanD albanD removed their request for review December 4, 2024 22:20
@bdhirsh
Contributor Author

bdhirsh commented Dec 4, 2024

Yep, it looks like torch.compiler.is_compiling() won't quite work in this case either.

I think the main issue is that is_compiling() works properly in "user" code today (code that dynamo executes bytecode for), but does not work as well in framework code that we need to trace. Checking for an active FakeTensorMode is probably the surest way of doing this today. Whether we want to add that check to is_compiling() is probably a conversation for later, so for now I'll leave the custom logic in ACT.
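
For what it's worth, here is a small hedged demo of the "user code" half of that distinction (assuming a recent PyTorch where torch.compiler.is_dynamo_compiling() exists); in bytecode that dynamo traces, the check is rewritten to a constant True, which is exactly what does not happen for framework code running under fake-tensor propagation:

```python
import torch


def user_fn(x):
    # In "user" code, dynamo intercepts this call and specializes it to True.
    if torch.compiler.is_dynamo_compiling():
        return x + 1  # branch taken under torch.compile
    return x + 2      # branch taken in plain eager mode


print(user_fn(torch.ones(2)))                 # eager: tensor([3., 3.])
print(torch.compile(user_fn)(torch.ones(2)))  # compiled: tensor([2., 2.])
```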

@bdhirsh
Contributor Author

bdhirsh commented Dec 4, 2024

I was wondering why the eager-mode behavior of DTensor cannot be like this.

hmm i might be misunderstanding your question, but eager DTensor doesn't do this because that would result in eager DTensor usage not overlapping compute and comms. Under compile, we force functional collectives to be synchronous because we can rely on the compiler to reorder the wait_tensor() calls later as necessary for good perf
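
A rough sketch of that contrast using torch.distributed._functional_collectives; this assumes a process group is already initialized and glosses over the exact eager sync mechanics (in eager, the AsyncCollectiveTensor result also syncs automatically on first use):

```python
import torch
import torch.distributed as dist
import torch.distributed._functional_collectives as funcol


def step(x: torch.Tensor, w: torch.Tensor) -> torch.Tensor:
    # Eager: returns an AsyncCollectiveTensor-backed result, so the matmul
    # below can overlap with the in-flight all_reduce.
    reduced = funcol.all_reduce(x, "sum", dist.group.WORLD)
    y = w @ w.t()                          # independent compute, can overlap
    reduced = funcol.wait_tensor(reduced)  # sync before consuming the result
    return reduced + y


# Under compile, _are_we_tracing() makes the wait happen at the call site, and
# the compiler is relied on to move it later for overlap where profitable.
compiled_step = torch.compile(step)
```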

@kwen2501
Contributor

kwen2501 commented Dec 4, 2024

I was thinking that if people need good perf, they should turn on torch.compile.
Then a DTensor op can have the same semantics as a regular tensor op -- when the op returns, the result is observable.

Nevertheless, thanks for the fix!

@yifuwang
Collaborator

yifuwang commented Dec 4, 2024

@bdhirsh oh I was referring to this (instead of is_dynamo_compiling):
https://github.com/pytorch/pytorch/blob/main/torch/compiler/__init__.py#L367


The documentation indicates that it should cover non-bytecode cases as well. But I don't think anything aside from export is setting _is_compiling_flag.
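
For context, a tiny hedged probe of the two public checks discussed in this thread; per the comments above, during dynamo's fake-tensor propagation both currently come back False, since only export sets the flag behind is_compiling():

```python
import torch

# Both are public torch.compiler APIs; in plain eager code this prints
# {'is_dynamo_compiling': False, 'is_compiling': False}.
print({
    "is_dynamo_compiling": torch.compiler.is_dynamo_compiling(),
    "is_compiling": torch.compiler.is_compiling(),
})
```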

@bdhirsh
Contributor Author

bdhirsh commented Dec 4, 2024

@yifuwang yeah that is fair. There are probably two ways we could fix is_compiling() so we can use it in ACT:

(1) check if a fake tensor mode is active. I'm not actually sure this is OK to do: FakeTensorMode is a user API, so someone could find it unintuitive for is_compiling() to return True under an (eager) FakeTensorMode

(2) do something similar to what is_compiling() does for export in that snippet: flip a global bool on entry/exit to dynamo. We could probably do this, although it does mean (a bit of) extra work in the hot path of every compiled-graph invocation (a rough sketch of this pattern is below).

I'll probably leave that discussion as a followup / separate from this PR
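
A rough sketch of the option-(2) pattern, with purely illustrative names (none of this is existing PyTorch internals):

```python
import contextlib

_inside_compiler = False  # hypothetical global flipped on compiler entry/exit


@contextlib.contextmanager
def compiler_region():
    """Hypothetical wrapper the compiler frontend would enter for each trace/run."""
    global _inside_compiler
    prev = _inside_compiler
    _inside_compiler = True
    try:
        yield
    finally:
        _inside_compiler = prev


def is_compiling_via_flag() -> bool:
    # A check like is_compiling() could simply read the flag, at the cost of a
    # little extra bookkeeping in the hot path of every graph invocation.
    return _inside_compiler
```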

@bdhirsh
Contributor Author

bdhirsh commented Dec 4, 2024

@pytorchbot merge

@pytorch-bot pytorch-bot bot added the ciflow/trunk Trigger trunk jobs on your pull request label Dec 4, 2024
@pytorchmergebot
Collaborator

Merge started

Your change will be merged once all checks pass (ETA 0-4 Hours).

Learn more about merging in the wiki.

Questions? Feedback? Please reach out to the PyTorch DevX Team

Advanced Debugging: check the merge workflow status here.

@yifuwang
Collaborator

yifuwang commented Dec 4, 2024

Yeah this change is good. I was just curious if there's an API that could work.

@wconstab
Contributor

wconstab commented Dec 4, 2024

I think this sgtm too. The only thing I want to check: for the AutoFSDP work, we model compute/comm time based on a FakeTensor trace. Will this be impacted by the change? I'm not sure. cc @weifengpy

@weifengpy
Contributor

I think this sgtm too. The only thing I want to check: for the AutoFSDP work, we model compute/comm time based on a FakeTensor trace. Will this be impacted by the change? I'm not sure. cc @weifengpy

should be good. for 1d fsdp, we trace vanilla modules (no collectives) and decide module wrapping policy

pobin6 pushed a commit to pobin6/pytorch that referenced this pull request Dec 5, 2024

Pull Request resolved: pytorch#142075
Approved by: https://github.com/yifuwang, https://github.com/kwen2501
ghstack dependencies: pytorch#141725, pytorch#141728
pytorch-bot bot pushed a commit that referenced this pull request Dec 9, 2024
@github-actions github-actions bot deleted the gh/bdhirsh/633/head branch January 5, 2025 02:10