preserve signatures with multiple calls + buffer mutations #138669

Conversation
test/export/test_export.py
Outdated
Can you also test unflatten on training_ir directly?
torch/export/exported_program.py
Outdated
What does "from_node" actually mean? Don't we also need to rewrite node.meta after run_decomp to reflect the change in "from_node"?
"from_node" keeps a history of how a node was generated from tracing other nodes.
So after the first decomp, it will contain the original node from export; after the second decomp, it will contain the original node followed by the node in the first decomp; etc.
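To make the "from_node" bookkeeping concrete, here is a minimal sketch (assuming a recent torch.export; the exact structure of each "from_node" entry varies across PyTorch versions) that prints the recorded ancestry after a decomposition:

```python
import torch

class M(torch.nn.Module):
    def forward(self, x):
        return torch.nn.functional.relu(x + 1)

ep = torch.export.export(M(), (torch.randn(3),))
decomposed = ep.run_decompositions()

# Each "from_node" entry records an ancestor this node was derived from during
# retracing/decomposition, oldest first; the exact entry format (tuples vs.
# richer objects) depends on the PyTorch version.
for node in decomposed.graph.nodes:
    print(node.name, node.meta.get("from_node"))
```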
torch/export/unflatten.py
Outdated
Hmm, this won't work for training IR because we don't have this information...
And fixing this is probably complicated. In the short term, what do you think about temporarily decomposing the training IR to figure out which buffers are mutated?
Note that this is only needed when buffer updates have been functionalized away. If the code contains direct buffer updates, then none of this is required. So it should be fine.
Hmm, I was wrong. It turns out that the output node of a mutation, rather than its input node, is what gets threaded through the rest of the program, so not every buffer mutation will take the placeholder node corresponding to the buffer as an input. E.g., add_ = buf.add_(1); add__1 = add_.add_(2).
I think temporarily decomposing the training IR is overkill. How can I detect mutating ops?
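For context, here is a sketch of one way mutating ops could be detected from an op's schema (a heuristic sketch, not necessarily what this PR ended up doing; node_mutates_input is a hypothetical helper name):

```python
import torch

def node_mutates_input(node: torch.fx.Node) -> bool:
    # Heuristic: an aten overload mutates an argument if its schema marks
    # that argument as a write alias (e.g. aten.add_.Tensor).
    if node.op != "call_function" or not isinstance(node.target, torch._ops.OpOverload):
        return False
    return any(
        arg.alias_info is not None and arg.alias_info.is_write
        for arg in node.target._schema.arguments
    )
```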
As called out in #137999, preserving signatures across multiple calls when buffer mutations are present was not yet implemented. The main problem was that intermediate values of buffers were not tracked, so they couldn't be propagated statefully between multiple calls (i.e., they would have to be explicitly passed around, defeating the unlifting needed for preserving signatures).
This PR fixes this by introducing module attributes that carry the necessary intermediate values of buffer mutations. In general, a buffer mutation can depend recursively on several intermediate values, even other buffers. So rather than tying an intermediate value to a particular buffer, we tie it to the submodules that create and read it. We install an attribute on every module that creates or reads a particular intermediate value, sharing the same initial storage (i.e., initialized with the same empty tensor). The module that creates the intermediate value copies it into the corresponding attribute; the modules that read it read the corresponding attribute instead.
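A conceptual sketch of that sharing scheme (hypothetical Producer/Consumer modules and a hypothetical shared tensor, not the actual unflatten code): both submodules hold an attribute backed by the same storage; the creator copies the intermediate value in, and the reader uses the attribute instead of a value passed around explicitly.

```python
import torch

# Hypothetical shared storage for one intermediate value of a buffer mutation.
shared = torch.empty(3)

class Producer(torch.nn.Module):
    def __init__(self):
        super().__init__()
        self.intermediate = shared          # same storage as Consumer.intermediate

    def forward(self, buf, x):
        updated = buf + x
        self.intermediate.copy_(updated)    # creator copies the value into the attribute
        return updated

class Consumer(torch.nn.Module):
    def __init__(self):
        super().__init__()
        self.intermediate = shared          # same storage as Producer.intermediate

    def forward(self, y):
        return self.intermediate + y        # reader uses the attribute instead

p, c = Producer(), Consumer()
p(torch.zeros(3), torch.ones(3))
print(c(torch.zeros(3)))  # sees the intermediate value produced by p
```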
Another complication that needed to be addressed was that a run_decompositions following an export_for_training did not preserve module call graphs, which are needed for unflattening and, in particular, used when remapping inputs. Fortunately, some existing metadata already tracks the provenance of nodes, which we could use to update the module call graph after functionalization / decomposition.

Differential Revision: D64806175
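Putting the pieces together, a sketch of the flow this enables (assuming export_for_training accepts preserve_module_call_signature like torch.export.export does; the module names are illustrative only):

```python
import torch
from torch.export import export_for_training, unflatten

class Sub(torch.nn.Module):
    def __init__(self):
        super().__init__()
        self.register_buffer("buf", torch.zeros(3))

    def forward(self, x):
        self.buf.add_(1)                 # buffer mutation
        return x + self.buf

class M(torch.nn.Module):
    def __init__(self):
        super().__init__()
        self.sub = Sub()

    def forward(self, x):
        return self.sub(self.sub(x))     # the same submodule is called twice

ep = export_for_training(
    M(),
    (torch.randn(3),),
    preserve_module_call_signature=("sub",),  # assumption: same kwarg as torch.export.export
)
ep = ep.run_decompositions()             # module call graph should survive decomposition
unflat = unflatten(ep)                   # unflattened module with sub's signature preserved
print(unflat(torch.randn(3)))
```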