[inductor] FX graph cache: Add support for symbolic shapes by masnesral · Pull Request #111421 · pytorch/pytorch · GitHub

Conversation

@masnesral
Contributor

@masnesral masnesral commented Oct 17, 2023

Stack from ghstack (oldest at bottom):

Summary: Add support for caching graphs that have tensor args with symbolic shapes. The high-level approach is to serialize guards with the on-disk cached object and validate that those guards pass before serving a cached object.

Test Plan: New unit tests
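The flow described in the summary can be sketched as follows. Everything here (compute_key, the entry dicts, the guard callables) is an illustrative stand-in for the real inductor internals, not the actual API:

```python
import hashlib
import pickle

def compute_key(graph_repr, config):
    # The key covers everything that affects codegen (graph, input
    # metadata, system settings, ...).
    return hashlib.sha256(pickle.dumps((graph_repr, config))).hexdigest()

def cached_compile(graph_repr, inputs, cache, compile_fn, make_guard):
    key = compute_key(graph_repr, config="default")
    for entry in cache.get(key, []):
        # Serve a cached entry only if its serialized guards pass for
        # the shapes of the current call.
        if entry["guards"](inputs):
            return entry["compiled"]
    compiled = compile_fn(graph_repr, inputs)
    cache.setdefault(key, []).append(
        {"compiled": compiled, "guards": make_guard(inputs)}
    )
    return compiled
```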

cc @voznesenskym @penguinwu @EikanWang @jgong5 @Guobing-Chen @XiaobingSuper @zhuhaozhe @blzheng @wenzhe-nrv @jiayisunx @peterbell10 @ipiszy @yf225 @chenyang78 @kadeng @muchulee8 @aakhundov @ColinPeppler


[ghstack-poisoned]
@pytorch-bot pytorch-bot bot added the release notes: fx release notes category label Oct 17, 2023
@pytorch-bot

pytorch-bot bot commented Oct 17, 2023

🔗 Helpful Links

🧪 See artifacts and rendered test results at hud.pytorch.org/pr/111421

Note: Links to docs will display an error until the docs builds have been completed.

✅ No Failures

As of commit e0135af with merge base 12a9e09:
💚 Looks good so far! There are no failures yet. 💚

This comment was automatically generated by Dr. CI and updates every 15 minutes.

masnesral added a commit that referenced this pull request Oct 17, 2023
Summary: Add support for caching graphs that have tensor args with symbolic shapes. The high-level approach is to serialize guards with the on-disk cached object and validate that those guards pass before serving a cached object.

Test Plan: New unit tests

ghstack-source-id: 4b027a0
Pull Request resolved: #111421
self.assertEqual(fn(a, b), compiled_fn(a, b))
self.assertEqual(counters["inductor"]["fxgraph_cache_miss"], 1)
self.assertEqual(counters["inductor"]["fxgraph_cache_hit"], 1)

Contributor Author

What's missing here is validation that guards were added after loading from the cache. Still investigating, but does anyone know if there's a straightforward way to access the guards in a structured way so I can do some validation here?

Contributor Author

@ezyang do you have a recommendation? If I have a torch.compiled function, is there a straightforward way to see that our guards were properly added even in the case of a cache hit?

.decode("utf-8")
.lower()
)
return "c" + sha256_hash(hashing_str)
Contributor Author

The prefix character seems to be the existing scheme here to differentiate different kinds of hashes for different types of cached objects. Probably I should use an enum. TODO.
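A sketch of the enum idea from the comment above. The member names, and which prefix maps to which kind of cached object, are guesses for illustration, not the codebase's actual scheme:

```python
from enum import Enum

class HashType(Enum):
    CODE = "c"              # hypothetical: generated source code
    COMPILED_GRAPH = "cg"   # hypothetical: pickled compiled FX graph

def prefixed_hash(hash_type: HashType, digest: str) -> str:
    # The prefix distinguishes what kind of cached object the hash names.
    return hash_type.value + digest
```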

Collaborator

yeah, when I wrote the original PR that was my thought too

Contributor

Is there any reason not to use a longer, more descriptive prefix? It's not like we have a file path length limitation or something.

Contributor Author

I was sticking with the convention already started, but I can definitely turn it into a more descriptive string. There's also an existing precedent for adding an extension, e.g., "cg", which I guessed means "code graph"?

self.fx_args = fx_args
def __init__(
self,
gm: torch.fx.GraphModule,
Contributor Author

Since we need to access the example_args, it was convenient to make the gm and example_args arguments explicit rather than packaging them up in a list.

# in an in-memory cache after loading from disk.
@classmethod
def save_graph(cls, key: str, compiled_graph: CompiledFxGraph):
@staticmethod
Contributor Author

I kept this implementation under the FxGraphCache class, but all the methods are static. I prefer the namespacing of putting these methods in the class, but lemme know if a class of all static methods hurts your eyeballs.

Contributor

Sometimes people prefer @classmethod and cls.method for this reason, I think. Up to you.

if hit:
# Now re-evaluate to add the guards to the current shape env.
# We have to clear the `evaluate_expr` lru_cache to force evaluation.
shape_env.evaluate_expr.cache_clear()
Contributor Author

Please comment. I don't know if clearing the whole cache is problematic, e.g., from a performance perspective. I could, for example, introduce a new context manager analogous to suppress_guards() above that selectively uses a cached vs. non-cached version of shape_env.evaluate_expr, but I thought it was worth asking if that's overkill.

Contributor

Seems off to me that we would clear cache. leaving this to @ezyang to comment

Collaborator

I agree it seems kinda weird that we would need to clear the cache. If a cache hit is preventing you from adding a guard, then that should mean that that guard was already evaluated, no?

Contributor Author

See some of the comments above. The approach here is to:

  1. Load a possible match from the cache.
  2. Evaluate the guards, but in a mode that does not modify the guards in the current environment, because in the case of a miss, we don't actually want the current env to change.
  3. If there's a hit, only then re-evaluate to cause the guards to be added.

The problem here is that the caching at evaluate_expr() interferes with step 3.

Contributor

@masnesral I think a better solution would be to customize the caching on evaluate_expr so that it respects whether or not guards were suppressed or not. WDYT?

Contributor Author

@ezyang is that correct though? If the expression has been evaluated before (even guards suppressed), don't we still want to evaluate it again in order to get the guards added?

Contributor

Hmm, well I guess it didn't matter when the hints were ints but it does matter if the hints are symints. I mean, I wouldn't be opposed to just turning off this caching entirely when you're playing funny tricks with symint hints.
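One way to realize the suggestion above (make the memoization aware of guard suppression) can be sketched with a toy stand-in for ShapeEnv. This is illustrative only: the real evaluate_expr operates on sympy expressions, not strings, and the class below is not the actual implementation:

```python
class ShapeEnvSketch:
    """Toy model of a shape environment with suppression-aware caching."""

    def __init__(self):
        self.guards = []
        self.suppress_guards = False
        self._eval_cache = {}

    def evaluate_expr(self, expr):
        # Only consult and populate the cache for non-suppressed
        # evaluations. That way a cache-lookup probe (run with guards
        # suppressed) never prevents a later real evaluation from
        # recording its guard.
        if not self.suppress_guards and expr in self._eval_cache:
            return self._eval_cache[expr]
        result = eval(expr)  # stand-in for symbolic evaluation
        if not self.suppress_guards:
            self.guards.append(expr)
            self._eval_cache[expr] = result
        return result
```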


_boxed_call: Optional[bool] = None

def __init__(
Contributor Author

I preferred a ctor here rather than extracting all these fields at the caller.

# Inputs to fx_codegen_and_compile
# Anything that affects codegen should go here, so if the signature
# of fx_codegen_and_compile changes, the list and dict should be updated accordingly
graph_args = [gm, example_inputs]
Contributor Author

See comment above. I changed this to be explicit about the two args (gm, example_inputs)


self._check_translation_validate()
return exprs

Contributor Author

The changes here are just about splitting apart the evaluate_guards_for_args function into smaller pieces that I can use for the cache impl. Namely, I need separate phases for creating the guards expression and evaluating it (in a different context)

Contributor

Seems like a good comment to put into the code comments
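The split described above (produce a guards expression, evaluate it later in a different context) can be sketched like this. The function names and the string-based representation are illustrative stand-ins; the real code works on ShapeEnv guards:

```python
def produce_guards_expr(symbols, guards):
    # Serialize the guards as a single boolean expression over the
    # placeholder symbol names, suitable for storing on disk with
    # the cached graph.
    return " and ".join(guards) if guards else "True"

def evaluate_guards_expr(expr, symbols, concrete_sizes):
    # Evaluate the stored expression with the symbols bound to the
    # concrete sizes of the current call (e.g. at cache-lookup time).
    return eval(expr, {}, dict(zip(symbols, concrete_sizes)))
```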

@masnesral masnesral added the topic: not user facing topic category label Oct 17, 2023
@masnesral masnesral marked this pull request as ready for review October 17, 2023 15:59
write(pickle.dumps(disk_compiled_graph), "cg", extra=key, hash_type="cg")

@classmethod
def load_graph(cls, cg_path: str) -> CompiledFxGraph:
Contributor Author

The diff looks weird here. load_graph is replaced by lookup_graph above and the implementation here is for save_graph. So ignore the red lines I guess.

Contributor

@eellison eellison left a comment

looks good but let's wait for full review from @ezyang

"""
See FxGraphCachePickler. Custom reducer to pickle SymInts.
"""
# For hashing purposes, we only care about the name of the symbol and
Contributor

It might be useful in the future to canonicalize symints so that hash(tensor([s1, s1, s2])) == hash(tensor([s3, s3, s4])).. not needed now

Contributor

It's not clear to me that this is sound, given https://github.com/pytorch/pytorch/pull/111421/files#diff-c9b517f8db609ffa866804dfa2689188a4fee20abacaa0b0dca91625c1b5cb8dR705

if you say [s0, s1] == [s4, s3], you have to make sure that you know how to flip the SymInt arguments when eval'ing the guards. That sounds difficult.
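The canonicalization idea can be sketched as renaming symbols in order of first appearance, so that e.g. [s1, s1, s2] and [s3, s3, s4] produce the same hashable form. As the follow-up comment notes, doing this soundly also requires keeping the mapping so guard arguments can be "flipped" back, which is the hard part. A hypothetical sketch:

```python
def canonicalize(symbols):
    # Rename symbols to s0, s1, ... in order of first appearance.
    mapping = {}
    out = []
    for s in symbols:
        if s not in mapping:
            mapping[s] = f"s{len(mapping)}"
        out.append(mapping[s])
    # Return the mapping too: it would be needed to rewrite guard
    # expressions against the renamed symbols.
    return out, mapping
```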



Comment on lines 757 to 758
path = os.path.join(subdir, sha256_hash(content) + ".cg")
write_atomic(path, content)
Contributor

do we need FileLock here ?

Contributor Author

It is held from the caller
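The pattern under discussion, writing to a temp file and atomically renaming it into place while any cross-process lock is held by the caller, can be sketched as follows. This is a generic sketch of the technique, not the actual helper in the codebase:

```python
import os
import tempfile

def write_atomic(path, content: bytes):
    # Write to a temporary file in the same directory, then atomically
    # rename over the destination so readers never observe a partial
    # file. Any FileLock serializing writers is assumed to be held by
    # the caller, as noted above.
    dirname = os.path.dirname(path) or "."
    fd, tmp = tempfile.mkstemp(dir=dirname)
    try:
        with os.fdopen(fd, "wb") as f:
            f.write(content)
        os.replace(tmp, path)  # atomic on POSIX within one filesystem
    except BaseException:
        os.unlink(tmp)
        raise
```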

FxGraph (the graph module, graph inputs, system settings etc.) into an
FxGraphCacheDetails object, pickle it, and compute a hash for the key.
See FxGraphCachePickler.
- Among the metadata we store, we also include a guards expression that's
Collaborator

Is it possible to end up loading a more generic graph from cache than we need?

Contributor Author

Ohhhh, yes. That's a nice catch. So really I need to consider all versions and pick the "best" option, don't I? Hmm, is there any obvious criteria I can use to determine whether one option is more appropriate than the other? cc @eellison

Contributor

I think maybe as a follow up - not needed in initial pr imo. Generally the behavior for symbolic shapes is to reuse a compilation if it works.

Contributor

Suppose you have a cache for (s0, 4) and another for (5, s1). There is no unique best choice cache entry to load.

Probably the only way to actually do this is some sort of autotuning based thing, where for any given concrete size you've benchmarked which one runs fastest, and you load that particular one.
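The lookup policy settled on above, with multiple candidate entries under one key (e.g. one specialized on (s0, 4) and another on (5, s1)), is simply to reuse the first compilation whose guards pass rather than trying to rank candidates. A hypothetical sketch:

```python
def lookup_graph(entries, concrete_sizes):
    # entries: candidate cache entries sharing one key; each carries a
    # guard predicate over the concrete input sizes of the current call.
    for entry in entries:
        if entry["guards"](concrete_sizes):
            return entry["compiled"]
    return None  # cache miss: caller compiles and appends a new entry
```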


@parametrize("device", ("cuda", "cpu"))
@parametrize("dtype", (torch.float32, torch.bfloat16))
def test_cache_load_function(self, device, dtype):
@parametrize("dynamic", (False, True))
Collaborator

Given that dynamic=None has a different behavior compared to True and False, it might be worth adding here as an option too

Contributor Author

So it turns out there's a shortcoming with that change. None seems to be equivalent to False in terms of the FX graph that gets generated (at least for these tests). So I'd need to rearchitect the tests slightly to make sure each test gets a clean tmp directory. I can do that, but I sorta liked the current behavior of leaving the tmp directory intact for the run of the full set of tests because we'd catch cases of a cache hit when we shouldn't (see my comment on setUpClass)

@parametrize("device", ("cuda", "cpu"))
@parametrize("dtype", (torch.float32, torch.bfloat16))
def test_cache_load_model(self, device, dtype):
@parametrize("dynamic", (False, True))
Collaborator

Same as previous comment


# Mark all tensor arg dimensions as dynamic to cause all shapes
# to be symbolic
def rand_dynamic(*args):
Collaborator

@ani300 ani300 Oct 19, 2023

this is equivalent to dynamic=True on the compile function

@masnesral masnesral changed the title [RFC][inductor] FX graph cache: Add support for symbolic shapes [inductor] FX graph cache: Add support for symbolic shapes Oct 30, 2023
@masnesral
Contributor Author

@pytorchbot merge

@pytorch-bot pytorch-bot bot added the ciflow/trunk Trigger trunk jobs on your pull request label Oct 31, 2023
@pytorchmergebot
Collaborator

Merge started

Your change will be merged once all checks pass (ETA 0-4 Hours).

Learn more about merging in the wiki.

Questions? Feedback? Please reach out to the PyTorch DevX Team

Advanced Debugging
Check the merge workflow status here.

@facebook-github-bot facebook-github-bot deleted the gh/masnesral/6/head branch November 4, 2023 14:26
xuhancn pushed a commit to xuhancn/pytorch that referenced this pull request Nov 7, 2023
…11421)

Summary: Add support for caching graphs that have tensor args with symbolic shapes. The high-level approach is to serialize guards with the on-disk cached object and validate that those guards pass before serving a cached object.

Test Plan: New unit tests

Pull Request resolved: pytorch#111421
Approved by: https://github.com/ezyang
Skylion007 pushed a commit to Skylion007/pytorch that referenced this pull request Nov 14, 2023
…11421)

Summary: Add support for caching graphs that have tensor args with symbolic shapes. The high-level approach is to serialize guards with the on-disk cached object and validate that those guards pass before serving a cached object.

Test Plan: New unit tests

Pull Request resolved: pytorch#111421
Approved by: https://github.com/ezyang

Labels

ciflow/inductor ciflow/trunk Trigger trunk jobs on your pull request Merged module: inductor release notes: fx release notes category topic: not user facing topic category


6 participants