unflatten with specialized graphs per submodule call by avikchaudhuri · Pull Request #137013 · pytorch/pytorch · GitHub

Conversation

@avikchaudhuri
Contributor

@avikchaudhuri avikchaudhuri commented Sep 30, 2024

Previously we were making a fairly restrictive assumption when unflattening an exported program: for any submodule, we would assert that the graph of every call to that submodule must be the same. This assertion is load-bearing, i.e., if we simply remove the assertion then we can get incorrect results, as shown by the following example.

    class N(torch.nn.Module):
        def forward(self, x, b):
            if b:
                return x + 1
            else:
                return x + 2

    class M(torch.nn.Module):
        def __init__(self):
            super().__init__()
            self.n = N()

        def forward(self, x):
            x0 = x + 3
            x1 = self.n(x0, True)
            x2 = x1 + 4
            x3 = self.n(x2, False)
            return x3 + 5

    m = M()
    inp = (torch.ones(1),)
    print(m(*inp))  # tensor([16.])
    ep = torch.export.export(m, inp)
    print(ep.module()(*inp))  # tensor([16.])

    unflattened = torch.export.unflatten(ep)
    print(unflattened(*inp))  # tensor([15.])

However, this goes against the spirit of specializing graphs when exporting: we should expect that for every call to a submodule we might generate a different graph. The goal of this PR is to fix unflattening to handle multiple specialized graphs corresponding to multiple calls to the same submodule.

The idea is simple: for every call to a child module foo, we create potentially different child modules foo, foo@1, foo@2, etc., and use those names as targets in call_module instructions in the parent graph. An immediate consequence is that the list of fqns in an unflattened module may not be the same as that of the exported module. Note that all these variants share the same parameters / buffers, so that multiple calls to the same submodule can share state as expected.

However, as described so far this scheme may end up with needlessly many submodules. Thus, between calls to the same submodule, if the graphs are equal then we optimize away the extra submodules and reuse call names as much as possible. Moreover, when submodules are shared across fqns, we also try to de-duplicate the graphs corresponding to their calls as much as possible. Note that no matter what, information about which submodule was called is still preserved, so that if a submodule has to be swapped with another, one can still find all calls to the former submodule and replace them with calls to the latter.
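The de-duplication idea can be simulated without torch (a sketch with invented names, not the PR's actual code): walk the calls to a submodule in order and only mint a new "@k" variant when the specialized graph differs from every one already seen.

```python
# Sketch of graph-equality-based de-duplication of call targets.
def assign_call_targets(submodule_name, graphs):
    """graphs: one hashable graph fingerprint per call, in call order."""
    targets = []
    variants = {}  # fingerprint -> assigned target name
    for g in graphs:
        if g not in variants:
            suffix = len(variants)
            # first distinct graph keeps the plain name; later ones get @k
            variants[g] = (submodule_name if suffix == 0
                           else f"{submodule_name}@{suffix}")
        targets.append(variants[g])
    return targets

# Three calls where the 1st and 3rd specialize to the same graph:
print(assign_call_targets("n", ["x+1", "x+2", "x+1"]))  # ['n', 'n@1', 'n']
```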

A note on the choice of naming scheme for call names: instead of generating "sibling" modules foo@1, foo@2, etc. for foo, we had considered generating "children" modules foo._1, foo._2, etc. of foo. However, this can cause spurious cycles when de-duplicating graphs. E.g., suppose foo is an alias for bar._1 and foo._1 is an alias for bar; then we must either introduce a cycle or drop the opportunity to optimize. Another idea was to make foo a dummy module that contains foo._0 corresponding to the first call, but this necessitates too many changes to existing tests and hurts the common case.

Differential Revision: D63642479

cc @XilunWu @H-Huang @awgu @kwen2501 @wanchaol @fegin @fduwjj @wz337 @wconstab @d4l3k @c-p-i-o

@pytorch-bot

pytorch-bot bot commented Sep 30, 2024

🔗 Helpful Links

🧪 See artifacts and rendered test results at hud.pytorch.org/pr/137013

Note: Links to docs will display an error until the docs builds have been completed.

✅ You can merge normally! (1 Unrelated Failure)

As of commit 964a4f1 with merge base 2b329d3:

FLAKY - The following job failed but was likely due to flakiness present on trunk:

This comment was automatically generated by Dr. CI and updates every 15 minutes.

@facebook-github-bot
Contributor

This pull request was exported from Phabricator. Differential Revision: D63642479


@pytorch-bot pytorch-bot bot added the oncall: distributed label on Sep 30, 2024

avikchaudhuri added a commit to avikchaudhuri/pytorch that referenced this pull request Sep 30, 2024
Summary: Pull Request resolved: pytorch#137013

Test Plan: added test

Differential Revision: D63642479

self.assertEqual(gm_flat_non_strict(*inp), gm_flat_strict(*inp))

def test_unflatten_multiple_graphs(self):
Contributor


One clarifying question: what would module swapping look like from the user side?

If they don't make distinctions between specialized graphs, it seems like they'd have to manually switch out n, n@1, and potentially p, p@1 if they're also aliasing? Or if they make distinctions for aliasing/specializations then some subset of those.

I don't have context for what swapping looks like today with aliasing - is it one or multiple swaps? - but the main point is, does this strictly introduce more work for swapping, and could we introduce some unflattener API for this?

Contributor Author


Yeah, they will have to update all f@i variants to new f. Maybe we should build some convenience APIs to grab these fqn variants.

for k, seen_module in self.seen_modules[self.module_id][:-1]:
num_calls[k] = num_calls.get(k, 0) + 1
seen_child_fqn = _call_name(k, num_calls[k])
if _check_graph_equivalence(seen_module, self.module):
Contributor


Does this mean that differently specialized graphs (e.g. N(x, False), N(x, True)) won't share state? As in if we do attribute swaps on foo.bar, it won't have the same change for foo.bar@1 if the computation is different.

Contributor Author


I think they will share state because the same params / buffer objects have been assigned to all variants. See assign_attr.
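The sharing the author describes can be illustrated with plain-Python stand-ins (FakeParam/FakeModule are invented here, not torch classes): assign_attr-style wiring points every call variant at the same parameter object, so in-place updates are visible to all variants.

```python
# Sketch: two call variants sharing one parameter object.
class FakeParam:
    """Stand-in for an nn.Parameter: a mutable box holding a value."""
    def __init__(self, value):
        self.value = value

class FakeModule:
    """Stand-in for an nn.Module; attributes are set directly."""

weight = FakeParam(1.0)             # one shared parameter object
n, n_at_1 = FakeModule(), FakeModule()
n.weight = weight                   # both variants reference the same object
n_at_1.weight = weight

n.weight.value += 1.0               # in-place update through one variant...
print(n_at_1.weight.value)          # ...is seen by the other: 2.0
```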

Contributor Author


The intended effect is that these variants are like different methods on the same "moral" instance.

Contributor


Oh, I meant if someone were to modify params based on the original FQNs, like foo.bar.attr = foo.bar.attr.bfloat16(), foo.bar@1.attr won't see the same change? I've debugged some internal FSDP sharding pipelines that'll do this to modify parameter dtype at runtime.

I think this falls under the same category of convenience APIs though, so no big deal

Contributor Author


Yeah I don't think it will see those changes unless they do it for all variants at the same time, so yeah, need that API. Any suggestions what that API should look like?
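The caveat under discussion comes down to rebinding versus in-place mutation, which a plain-Python sketch (invented stand-in classes, not torch) makes concrete: reassigning an attribute replaces the reference on that one variant only, so the "@1" sibling keeps the old object.

```python
# Sketch: attribute rebinding on one variant is not seen by its sibling.
class FakeModule:
    """Stand-in for an nn.Module."""

shared = [1.0]                      # stand-in for a shared parameter tensor
bar, bar_at_1 = FakeModule(), FakeModule()
bar.attr = shared
bar_at_1.attr = shared

# Rebinding, analogous to foo.bar.attr = foo.bar.attr.bfloat16():
bar.attr = [0.5]
print(bar.attr, bar_at_1.attr)      # [0.5] [1.0]
```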

Contributor


For module swapping we can probably just patch the logic into Angela's _swap_modules() method in unflatten.py? For attributes, probably something similar: def _swap_attributes(ep: ExportedProgram, attrs_to_swap: Dict[str, Union[Any, Callable[[Any], Any]]]):
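A purely hypothetical sketch of what such a convenience helper could look like (swap_attribute, its signature, and the dict-of-modules input are all invented here, not an existing torch.export API): apply an update to an attribute across a base fqn and every one of its "@k" call variants so they stay in sync.

```python
# Hypothetical helper: update an attribute on all call variants of an fqn.
def swap_attribute(modules, fqn_attr, update):
    """modules: dict mapping fqn -> module-like object.
    fqn_attr: dotted path like 'foo.bar.attr'.
    update: callable taking the old value and returning the new one."""
    fqn, attr = fqn_attr.rsplit(".", 1)
    for name, mod in modules.items():
        # match the base fqn and any call variant foo.bar@1, foo.bar@2, ...
        if name == fqn or name.startswith(fqn + "@"):
            setattr(mod, attr, update(getattr(mod, attr)))

# Usage with stand-in modules:
class _M:
    pass

a, b, c = _M(), _M(), _M()
a.scale, b.scale, c.scale = 1, 1, 1
mods = {"foo.bar": a, "foo.bar@1": b, "foo.baz": c}
swap_attribute(mods, "foo.bar.scale", lambda v: v * 2)
print(a.scale, b.scale, c.scale)  # 2 2 1
```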

@pytorch-bot pytorch-bot bot added the ciflow/trunk label on Oct 1, 2024

avikchaudhuri added a commit to avikchaudhuri/pytorch that referenced this pull request Oct 2, 2024
Summary: Pull Request resolved: pytorch#137013

Test Plan: added test

Reviewed By: pianpwk

Differential Revision: D63642479

@facebook-github-bot
Contributor

@pytorchbot merge -f 'Landed internally'

(Initiating merge automatically since Phabricator Diff has merged, using force because this PR might not pass merge_rules.json but landed internally)

@pytorchmergebot
Collaborator

Merge started

Your change will be merged immediately since you used the force (-f) flag, bypassing any CI checks (ETA: 1-5 minutes). Please use -f as last resort and instead consider -i/--ignore-current to continue the merge ignoring current failures. This will allow currently pending tests to finish and report signal before the merge.

Learn more about merging in the wiki.

Questions? Feedback? Please reach out to the PyTorch DevX Team

Advanced Debugging: check the merge workflow status here.


Labels

ciflow/trunk, fb-exported, Merged, oncall: distributed, release notes: export
