preserve signatures with multiple calls + buffer mutations by avikchaudhuri · Pull Request #138669 · pytorch/pytorch · GitHub

Conversation

@avikchaudhuri
Contributor

@avikchaudhuri commented Oct 23, 2024

As called out in #137999, preserving signatures of multiple calls when buffer mutations are present was not yet implemented. The main problem was that intermediate values of buffers were not tracked, so they could not be propagated statefully between multiple calls (i.e., they would have needed to be explicitly passed around, defeating the unlifting needed for preserving signatures).

This PR fixes this situation by introducing module attributes that carry the necessary intermediate values of buffer mutations. In general, a buffer mutation can depend recursively on several intermediate values, even other buffers. So rather than tying an intermediate value to a particular buffer, we tie it to the submodules that create and read it. We install an attribute on all modules that create or read a particular intermediate value, sharing the same initial storage (i.e., initialized with the same empty tensor). The module that creates this intermediate value copies the value into the corresponding attribute; the modules that read it read the corresponding attribute instead.
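For illustration, here is a minimal sketch of the sharing scheme described above (module and attribute names are hypothetical, not the actual implementation):

```python
import torch
import torch.nn as nn

# Hypothetical setup: one submodule creates an intermediate value of a buffer
# mutation, another submodule reads it. Both register an attribute backed by
# the same initial (empty) tensor, i.e. they share storage.
shared_init = torch.empty(3)

producer, consumer = nn.Module(), nn.Module()
for m in (producer, consumer):
    m.register_buffer("_intermediate_buf_update", shared_init)

def producer_step(buf, x):
    updated = buf + x
    # the module that creates the intermediate value copies it into its attribute
    producer._intermediate_buf_update.copy_(updated)
    return updated

def consumer_step(y):
    # modules that read the intermediate value read their attribute instead
    return consumer._intermediate_buf_update + y
```

Because the attributes alias the same storage, the value copied in by the producer is exactly what the readers observe, without the intermediate value being passed around explicitly.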

Another complication that needed to be addressed was that a run_decompositions following an export_for_training was not preserving module call graphs, which is needed for unflattening and, in particular, used when remapping inputs. Fortunately some existing metadata already tracks provenance of nodes, which we could use to update a module call graph after functionalization / decomposition.
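For context, a sketch of the end-to-end flow this fix targets; the toy module and the `preserve_module_call_signature` argument are assumptions for illustration, mirroring `torch.export.export`:

```python
import torch

class Sub(torch.nn.Module):
    def __init__(self):
        super().__init__()
        self.register_buffer("buf", torch.zeros(2))

    def forward(self, x):
        self.buf.add_(1)            # buffer mutation
        return x + self.buf

class M(torch.nn.Module):
    def __init__(self):
        super().__init__()
        self.sub = Sub()

    def forward(self, x):
        return self.sub(x) + self.sub(x)   # multiple calls to the same submodule

ep = torch.export.export_for_training(
    M(),
    (torch.randn(2),),
    preserve_module_call_signature=("sub",),  # assumed kwarg for this sketch
)
ep = ep.run_decompositions()           # should now keep the module call graph
unflat = torch.export.unflatten(ep)    # unflattening relies on that call graph
```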

Differential Revision: D64806175

@pytorch-bot

pytorch-bot bot commented Oct 23, 2024

🔗 Helpful Links

🧪 See artifacts and rendered test results at hud.pytorch.org/pr/138669

Note: Links to docs will display an error until the docs builds have been completed.

✅ No Failures

As of commit fe87858 with merge base fe458ee:
💚 Looks good so far! There are no failures yet. 💚

This comment was automatically generated by Dr. CI and updates every 15 minutes.

@facebook-github-bot
Contributor

This pull request was exported from Phabricator. Differential Revision: D64806175


@avikchaudhuri force-pushed the export-D64806175 branch 2 times, most recently from 851d564 to 0ac21b6 on October 23, 2024 at 19:37
@facebook-github-bot
Contributor

This pull request was exported from Phabricator. Differential Revision: D64806175


Contributor

Can you also test unflatten on training_ir directly?

Contributor

What does "from_node" mean actually? Shouldn't we also need to rewrite the node.meta after run_decomp to reflect the change in "from_node"?

Contributor Author

"from_node" keeps a history of how a node was generated from tracing other nodes.

Contributor Author

So after the first decomp, it will contain the original node from export; after the second decomp, it will contain the original node followed by the node in the first decomp; etc.
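For illustration, a minimal sketch of inspecting that provenance chain; the exact entry format of `node.meta["from_node"]` has varied across PyTorch versions, so treat the details as an assumption:

```python
# 'ep' is assumed to be an ExportedProgram produced by torch.export.
for node in ep.graph_module.graph.nodes:
    history = node.meta.get("from_node", [])
    if history:
        # each entry records an earlier node this one was generated from;
        # repeated retracing/decomposition appends to the list in order
        print(node.name, "<-", history)
```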

Contributor

Hmm, this won't work for training IR because we don't have this information...

Contributor

@tugsbayasgalan Oct 23, 2024

And fixing this is probably complicated; what do you think about temporarily decomposing the training IR to figure out which buffers are mutated, as a short-term solution?

Contributor Author

Note that this is only needed when buffer updates have been functionalized away. If the code contains direct buffer updates, then none of this is required. So it should be fine.
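To make the distinction concrete, a toy example (names are illustrative): with a direct buffer update in the module code, the mutation is visible as-is in the graph; only after functionalization does it become an out-of-place op whose intermediate result needs to be tracked as described in this PR.

```python
import torch

class Counter(torch.nn.Module):
    def __init__(self):
        super().__init__()
        self.register_buffer("count", torch.zeros(()))

    def forward(self, x):
        self.count.add_(1)          # direct, in-place buffer update
        return x * self.count

# After functionalization, the mutation is (roughly) rewritten as
#   new_count = aten.add.Tensor(count, 1)   # out-of-place
#   ... uses of new_count ...
#   copy_(count, new_count)                 # write-back epilogue
# and it is new_count, the intermediate value, that must be propagated
# between multiple calls when signatures are preserved.
```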

Contributor Author

@avikchaudhuri Oct 23, 2024

Hmm, I was wrong. It turns out that the output node, rather than the input node, of a mutation is what is threaded through the rest of the program, so not every buffer mutation will take the buffer's placeholder node as input. E.g., `add_ = buf.add_(1); add__1 = add_.add_(2)`.

I think temporarily decomposing the training IR is overkill. How can I detect mutating ops?
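For reference, one way to spot mutating ATen ops in a graph is to consult the op schema; a minimal sketch, not necessarily what this PR ended up doing:

```python
import torch

def is_mutating(node: torch.fx.Node) -> bool:
    """A call_function node mutates an input if its schema marks the op as
    mutable, i.e. some argument is written to."""
    if node.op != "call_function" or not isinstance(node.target, torch._ops.OpOverload):
        return False
    schema = node.target._schema
    return schema.is_mutable or any(
        a.alias_info is not None and a.alias_info.is_write
        for a in schema.arguments
    )
```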

@facebook-github-bot
Contributor

This pull request was exported from Phabricator. Differential Revision: D64806175


@avikchaudhuri force-pushed the export-D64806175 branch 2 times, most recently from f071676 to 94caa3b on October 24, 2024 at 16:25
@facebook-github-bot
Contributor

This pull request was exported from Phabricator. Differential Revision: D64806175

preserve signatures with multiple calls + buffer mutations (pytorch#138669)

Summary: Pull Request resolved: pytorch#138669

Test Plan: modified test

Differential Revision: D64806175
@facebook-github-bot
Contributor

This pull request was exported from Phabricator. Differential Revision: D64806175

@pytorch-bot bot added the ciflow/trunk label (Trigger trunk jobs on your pull request) on Oct 24, 2024
@facebook-github-bot
Contributor

@pytorchbot merge

(Initiating merge automatically since Phabricator Diff has merged)

@pytorchmergebot
Collaborator

Merge started

Your change will be merged once all checks pass (ETA 0-4 Hours).

Learn more about merging in the wiki.

Questions? Feedback? Please reach out to the PyTorch DevX Team

Advanced Debugging: check the merge workflow status here.

@avikchaudhuri
Contributor Author

@pytorchbot merge -f "Landed internally"

@pytorchmergebot
Collaborator

The merge job was canceled or timed out. This most often happens if two merge requests were issued for the same PR, or if the merge job was waiting for more than 6 hours for tests to finish. In the latter case, please do not hesitate to reissue the merge command.
For more information see pytorch-bot wiki.

@pytorchmergebot
Collaborator

Merge started

Your change will be merged immediately since you used the force (-f) flag, bypassing any CI checks (ETA: 1-5 minutes). Please use -f as a last resort and instead consider -i/--ignore-current to continue the merge while ignoring current failures. This will allow currently pending tests to finish and report signal before the merge.

Learn more about merging in the wiki.

Questions? Feedback? Please reach out to the PyTorch DevX Team

Advanced Debugging: check the merge workflow status here.
