preserve module signature with multiple calls by avikchaudhuri · Pull Request #137999 · pytorch/pytorch · GitHub

Conversation

@avikchaudhuri
Contributor

@avikchaudhuri avikchaudhuri commented Oct 15, 2024

Previously we would error when trying to preserve the call signature of a module that was called multiple times. With this PR, this now works without erroring. The fix is to propagate call indices in a few more places.

Note that while this works in the presence of params, buffers, and tensor constants, preserving call signatures for multiple calls to a module when buffers are mutated is not supported yet. This is future work. The main problem is that we do not have enough metadata to copy_ mutated buffers at the end of each call to a module so that the next call can read those buffers at the beginning. Making this work will likely need some explicit tracking of intermediate values of mutated buffers when collecting metadata during functionalization in export.

Note also that we stop short of creating a single graph out of multiple graphs: that is still future work. So the unflattened module will still have different targets n, n@1, n@2, etc. for each call when we ask for the module call signature of n to be preserved. However, it is much easier to swap all of these targets with a replacement that behaves similarly to the original, because all of these calls will respect the original module call signature. (In particular, any constant inputs will be carried by the calls.)
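For illustration, here is a minimal sketch of the scenario this enables (not the PR's actual test; the toy modules `M`, `N`, and the shapes are made up, while `torch.export.export(..., preserve_module_call_signature=...)` and `torch.export.unflatten` are the entry points exercised by the tests in this PR):

```python
import torch

class N(torch.nn.Module):
    def forward(self, x):
        return x + 1

class M(torch.nn.Module):
    def __init__(self):
        super().__init__()
        self.n = N()

    def forward(self, x):
        # The same submodule is called twice. Previously, asking to preserve
        # the call signature of "n" would error in unflatten; now it works.
        return self.n(self.n(x))

inp = (torch.randn(3),)
ep = torch.export.export(M(), inp, preserve_module_call_signature=("n",))
unflat = torch.export.unflatten(ep)

# The unflattened module keeps separate targets "n" and "n@1" for the two calls,
# but each call respects N's original signature, so swapping both targets with a
# replacement module (as in the swap={"n": ..., "n@1": ...} test quoted in the
# review comments below) is straightforward.
torch.testing.assert_close(unflat(*inp), M()(*inp))
```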

Differential Revision: D64406945

@pytorch-bot

pytorch-bot bot commented Oct 15, 2024

🔗 Helpful Links

🧪 See artifacts and rendered test results at hud.pytorch.org/pr/137999

Note: Links to docs will display an error until the docs builds have been completed.

✅ You can merge normally! (2 Unrelated Failures)

As of commit 2755f65 with merge base 7a117f3:

FLAKY - The following job failed but was likely due to flakiness present on trunk:

BROKEN TRUNK - The following job failed but was present on the merge base:

👉 Rebase onto the `viable/strict` branch to avoid these failures

This comment was automatically generated by Dr. CI and updates every 15 minutes.

@facebook-github-bot
Contributor

This pull request was exported from Phabricator. Differential Revision: D64406945

@avikchaudhuri avikchaudhuri changed the title preserve module signature multiple calls preserve module signature with multiple calls Oct 15, 2024
@facebook-github-bot
Contributor

This pull request was exported from Phabricator. Differential Revision: D64406945

1 similar comment
@facebook-github-bot
Contributor

This pull request was exported from Phabricator. Differential Revision: D64406945

avikchaudhuri added a commit to avikchaudhuri/pytorch that referenced this pull request Oct 15, 2024
Summary: Pull Request resolved: pytorch#137999

Test Plan: fixed tests

Differential Revision: D64406945
self.parent_call_module: Optional[torch.fx.Node] = None
if parent is not None:
    if self.fqn in module_call_graph and num_calls == 1:
        raise ValueError(
Contributor Author

@avikchaudhuri avikchaudhuri Oct 15, 2024


main thing that changed

if not is_retracebility_test(self._testMethodName):
    test(
        export(M(), inp, preserve_module_call_signature=("n",)),
        swap={"n": N(), "n@1": N()},
Contributor Author


easy swap

"Cannot unflatten multiple calls to module n while preserving its signature",
):
torch.export.unflatten(ep)
def test(ep, swap):
Contributor


Can you move this test to test_unflatten.py? I want to make sure it works with training IR as well. The test here exercises training IR + run_decompositions, so it is still operating on a functional IR.

Contributor Author


I don't like test_unflatten.py because, in general, more variations are tested here. But good point that I should test this with state as well; I'll hold off landing until I add those tests.

For training IR only, I could add a separate explicit test?

@pytorch-bot pytorch-bot bot added the ciflow/trunk Trigger trunk jobs on your pull request label Oct 16, 2024
@facebook-github-bot
Contributor

This pull request was exported from Phabricator. Differential Revision: D64406945

1 similar comment
@facebook-github-bot
Contributor

This pull request was exported from Phabricator. Differential Revision: D64406945

avikchaudhuri added a commit to avikchaudhuri/pytorch that referenced this pull request Oct 17, 2024
Summary: Pull Request resolved: pytorch#137999

Test Plan: fixed tests

Reviewed By: tugsbayasgalan

Differential Revision: D64406945
@avikchaudhuri avikchaudhuri force-pushed the export-D64406945 branch 2 times, most recently from cf754c2 to 9adf5c0, on October 17, 2024 at 19:49
@facebook-github-bot
Contributor

This pull request was exported from Phabricator. Differential Revision: D64406945

Summary: Pull Request resolved: pytorch#137999

Test Plan: fixed tests

Reviewed By: tugsbayasgalan

Differential Revision: D64406945
@facebook-github-bot
Contributor

This pull request was exported from Phabricator. Differential Revision: D64406945

@facebook-github-bot
Contributor

@pytorchbot merge

(Initiating merge automatically since Phabricator Diff has merged)

@pytorchmergebot
Collaborator

Merge started

Your change will be merged once all checks pass (ETA 0-4 Hours).

Learn more about merging in the wiki.

Questions? Feedback? Please reach out to the PyTorch DevX Team

Advanced Debugging
Check the merge workflow status here.

@avikchaudhuri
Contributor Author

@pytorchbot merge -f "Landed internally"

@pytorchmergebot
Collaborator

The merge job was canceled or timed out. This most often happens if two merge requests were issued for the same PR, or if the merge job was waiting for more than 6 hours for tests to finish. In the latter case, please do not hesitate to reissue the merge command.
For more information see pytorch-bot wiki.

@pytorchmergebot
Collaborator

Merge started

Your change will be merged immediately since you used the force (-f) flag, bypassing any CI checks (ETA: 1-5 minutes). Please use -f as last resort and instead consider -i/--ignore-current to continue the merge ignoring current failures. This will allow currently pending tests to finish and report signal before the merge.

Learn more about merging in the wiki.

Questions? Feedback? Please reach out to the PyTorch DevX Team

Advanced Debugging
Check the merge workflow status here.

pytorchmergebot pushed a commit that referenced this pull request Oct 25, 2024
As called out in #137999, preserving signatures of multiple calls when buffer mutations are present was not yet implemented. The main problem was that intermediate values of buffers were not tracked, so they could not be propagated statefully between multiple calls (i.e., they would need to be explicitly passed around, defeating the unlifting needed for preserving signatures).

This PR fixes this situation by introducing module attributes that carry the necessary intermediate values of buffer mutations. In general, a buffer mutation can depend recursively on several intermediate values, even other buffers. So rather than tying an intermediate value to a particular buffer, we tie it to the submodules that create and read it. We install an attribute on all modules that create or read a particular intermediate value, sharing the same initial storage (i.e., initialized with the same empty tensor). The module that creates the intermediate value copies it into the corresponding attribute; the modules that read it read the corresponding attribute instead.

Another complication that needed to be addressed was that a `run_decompositions` following an `export_for_training` was not preserving module call graphs, which is needed for unflattening and, in particular, used when remapping inputs. Fortunately some existing metadata already tracks provenance of nodes, which we could use to update a module call graph after functionalization / decomposition.

Differential Revision: D64806175

Pull Request resolved: #138669
Approved by: https://github.com/tugsbayasgalan
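For illustration, a hedged sketch of the buffer-mutation case that this follow-up enables (the toy modules `Counter` and `M` are made up; assuming the behavior described in the commit message above, the second call to the submodule observes the buffer value mutated by the first call after unflattening):

```python
import torch

class Counter(torch.nn.Module):
    def __init__(self):
        super().__init__()
        self.register_buffer("count", torch.zeros(1))

    def forward(self, x):
        # Buffer mutation whose intermediate value must flow from the first
        # call of this submodule to the second.
        self.count.add_(1)
        return x + self.count

class M(torch.nn.Module):
    def __init__(self):
        super().__init__()
        self.c = Counter()

    def forward(self, x):
        # Two calls to the same buffer-mutating submodule.
        return self.c(self.c(x))

inp = (torch.randn(1),)
ep = torch.export.export(M(), inp, preserve_module_call_signature=("c",))
unflat = torch.export.unflatten(ep)
# Per the commit message above, unflatten installs shared module attributes that
# carry the intermediate buffer value between the "c" and "c@1" calls.
```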
