[inductor] FX graph cache: Add support for symbolic shapes #111421
Conversation
Summary: Add support for caching graphs that have tensor args with symbolic shapes. The high-level approach is to serialize guards with the on-disk cached object and validate that those guards pass before serving a cached object. Test Plan: New unit tests
See artifacts and rendered test results at hud.pytorch.org/pr/111421. (Links to docs will display an error until the docs builds have completed.)
✅ No failures as of commit e0135af with merge base 12a9e09. (This comment was automatically generated by Dr. CI and updates every 15 minutes.)
test/inductor/test_codecache.py
Outdated
    self.assertEqual(fn(a, b), compiled_fn(a, b))
    self.assertEqual(counters["inductor"]["fxgraph_cache_miss"], 1)
    self.assertEqual(counters["inductor"]["fxgraph_cache_hit"], 1)
What's missing here is validation that guards were added after loading from the cache. Still investigating, but does anyone know if there's a straightforward way to access the guards in a structured way so I can do some validation here?
@ezyang do you have a recommendation? If I have a torch.compiled function, is there a straightforward way to see that our guards were properly added even in the case of a cache hit?
| .decode("utf-8") | ||
| .lower() | ||
| ) | ||
| return "c" + sha256_hash(hashing_str) |
The prefix character seems to be the existing scheme here to differentiate different kinds of hashes for different types of cached objects. Probably I should use an enum. TODO.
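To make the enum idea concrete, here is a hedged sketch of what the prefix scheme could look like. The names (`CacheEntryType`, the truncated `sha256_hash` helper) are illustrative assumptions, not the actual `codecache.py` API:

```python
import hashlib
from enum import Enum


class CacheEntryType(Enum):
    # Hypothetical mapping of prefixes to cached-object kinds; "c" and "cg"
    # mirror the single-character convention already used in codecache.py.
    CODE = "c"
    COMPILED_GRAPH = "cg"


def sha256_hash(data: str) -> str:
    # Truncated sha256 hex digest, analogous to the existing hash helper.
    return hashlib.sha256(data.encode("utf-8")).hexdigest()[:32]


def cache_key(entry_type: CacheEntryType, hashing_str: str) -> str:
    # Prefixing the hash keeps different kinds of cached objects from
    # colliding in the same on-disk namespace.
    return entry_type.value + sha256_hash(hashing_str)


key = cache_key(CacheEntryType.CODE, "some canonical source text")
```

An enum centralizes the prefix-to-kind mapping in one place, so a new cached-object type can't silently reuse an existing prefix.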
Yeah, when I wrote the original PR, that was my thought too.
Is there any reason not to use a longer, more descriptive prefix? It's not like we have a file path length limitation or something
I was sticking with the convention already started, but can definitely turn it into a more descriptive string. There's also an existing precedent for adding an extension, e.g., "cg", which I guessed means "code graph"?
    self.fx_args = fx_args
def __init__(
    self,
    gm: torch.fx.GraphModule,
Since we need to access the example_args, it was convenient to make the gm and example_args arguments explicit rather than packaging them up in a list.
    # in an in-memory cache after loading from disk.
    @classmethod
    def save_graph(cls, key: str, compiled_graph: CompiledFxGraph):
    @staticmethod
I kept this implementation under the FxGraphCache class, but all the methods are static. I prefer the namespacing of putting these methods in the class, but lemme know if a class of all static methods hurts your eyeballs.
Sometimes people prefer @classmethod and cls.method for this reason, I think. Up to you.
torch/_inductor/codecache.py
Outdated
    if hit:
        # Now re-evaluate to add the guards to the current shape env.
        # We have to clear the `evaluate_expr` lru_cache to force evaluation.
        shape_env.evaluate_expr.cache_clear()
Please comment. I don't know if clearing the whole cache is problematic, e.g., from a performance perspective. I could, for example, introduce a new context manager analogous to suppress_guards() above that selectively uses a cached vs. non-cached version of shape_env.evaluate_expr, but I thought it was worth asking if that's overkill.
Seems off to me that we would clear cache. leaving this to @ezyang to comment
I agree it seems kinda weird that we would need to clear the cache. If a cache hit is preventing you from adding a guard, then that should mean that that guard was already evaluated, no?
See some of the comments above. The approach here is to:
- Load a possible match from the cache.
- Evaluate the guards, but in a mode that does not modify the guards in the current environment, because: in the case that there's a miss, we don't actually want the current env to change.
- If there's a hit, only then re-evaluate to cause the guards to be added.
The problem here is that the caching at evaluate_expr() interferes with step 3 above.
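The interference can be shown with a self-contained toy. `ToyShapeEnv` below is an illustration, not the real ShapeEnv: it just captures the essentials that evaluating an expression records a guard as a side effect unless guards are suppressed, and that `evaluate_expr` is memoized:

```python
from functools import lru_cache


class ToyShapeEnv:
    # Toy stand-in for ShapeEnv (an assumption for illustration only):
    # evaluating an expression records a guard as a side effect,
    # unless guards are currently suppressed.
    def __init__(self):
        self.guards = []
        self.suppress_guards = False

    @lru_cache(maxsize=None)
    def evaluate_expr(self, expr):
        if not self.suppress_guards:
            self.guards.append(expr)
        return True


env = ToyShapeEnv()

# Step 2: evaluate with guards suppressed while probing the cache entry.
env.suppress_guards = True
env.evaluate_expr("s0 <= 8")
assert env.guards == []  # nothing recorded, as intended

# Step 3: on a hit, re-evaluate to install the guards... but lru_cache
# remembers the suppressed call and skips the body entirely.
env.suppress_guards = False
env.evaluate_expr("s0 <= 8")
assert env.guards == []  # guard silently dropped!

# Hence the cache_clear() in this PR: force a real re-evaluation.
env.evaluate_expr.cache_clear()
env.evaluate_expr("s0 <= 8")
assert env.guards == ["s0 <= 8"]
```

Because the memoization key doesn't include the suppression state, the side-effecting evaluation never runs again; clearing the cache (or keying it on suppression, as suggested below) restores the side effect.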
@masnesral I think a better solution would be to customize the caching on evaluate_expr so that it respects whether or not guards were suppressed or not. WDYT?
@ezyang is that correct though? If the expression has been evaluated before (even guards suppressed), don't we still want to evaluate it again in order to get the guards added?
Hmm, well I guess it didn't matter when the hints were ints but it does matter if the hints are symints. I mean, I wouldn't be opposed to just turning off this caching entirely when you're playing funny tricks with symint hints.
    _boxed_call: Optional[bool] = None

    def __init__(
I preferred a ctor here rather than extracting all these fields at the caller.
    # Inputs to fx_codegen_and_compile
    # Anything that affects codegen should go here, so if the signature
    # of fx_codegen_and_compile changes, the list and dict should be updated accordingly
    graph_args = [gm, example_inputs]
See comment above. I changed this to be explicit about the two args (gm, example_inputs)
    self._check_translation_validate()
    return exprs
The changes here are just about splitting apart the evaluate_guards_for_args function into smaller pieces that I can use for the cache impl. Namely, I need separate phases for creating the guards expression and evaluating it (in a different context).
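The two-phase split described above can be sketched as follows. The function names mirror the description in this thread, but the bodies are illustrative assumptions (the real code builds the expression from ShapeEnv guards; plain `eval` here is just a stand-in for however the expression is interpreted after a cache load):

```python
def produce_guards_expression(guards):
    # Phase 1: conjoin the textual guards into a single serializable
    # expression that can be pickled alongside the cached graph.
    return " and ".join(f"({g})" for g in guards)


def evaluate_guards_expression(expr, hints):
    # Phase 2: evaluate the expression with concrete symbol hints bound,
    # possibly in a different process long after the expression was built.
    return bool(eval(expr, {"__builtins__": {}}, dict(hints)))


expr = produce_guards_expression(["s0 <= 8", "s1 % 2 == 0"])

assert evaluate_guards_expression(expr, {"s0": 4, "s1": 6})
assert not evaluate_guards_expression(expr, {"s0": 16, "s1": 6})
```

Separating production from evaluation is what lets the cache store the expression on disk at save time and check it against fresh example inputs at lookup time.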
Seems like a good comment to put into the code comments
    write(pickle.dumps(disk_compiled_graph), "cg", extra=key, hash_type="cg")

    @classmethod
    def load_graph(cls, cg_path: str) -> CompiledFxGraph:
The diff looks weird here. load_graph is replaced by lookup_graph above and the implementation here is for save_graph. So ignore the red lines I guess.
Looks good, but let's wait for a full review from @ezyang.
    """
    See FxGraphCachePickler. Custom reducer to pickle SymInts.
    """
    # For hashing purposes, we only care about the name of the symbol and
It might be useful in the future to canonicalize symints so that hash(tensor([s1, s1, s2])) == hash(tensor([s3, s3, s4])). Not needed now.
It's not clear to me that this is sound, given https://github.com/pytorch/pytorch/pull/111421/files#diff-c9b517f8db609ffa866804dfa2689188a4fee20abacaa0b0dca91625c1b5cb8dR705
if you say [s0, s1] == [s4, s3], you have to make sure that you know how to flip the SymInt arguments when eval'ing the guards. That sounds difficult.
torch/_inductor/codecache.py
Outdated
    path = os.path.join(subdir, sha256_hash(content) + ".cg")
    write_atomic(path, content)
Do we need FileLock here?
It is held by the caller.
    FxGraph (the graph module, graph inputs, system settings etc.) into an
    FxGraphCacheDetails object, pickle it, and compute a hash for the key.
    See FxGraphCachePickler.
    - Among the metadata we store, we also include a guards expression that's
Is it possible to end up loading a more generic graph from cache than we need?
Ohhhh, yes. That's a nice catch. So really I need to consider all versions and pick the "best" option, don't I? Hmm, is there any obvious criteria I can use to determine whether one option is more appropriate than the other? cc @eellison
I think maybe as a follow-up; not needed in the initial PR, imo. Generally, the behavior for symbolic shapes is to reuse a compilation if it works.
Suppose you have a cache for (s0, 4) and another for (5, s1). There is no unique best choice cache entry to load.
Probably the only way to actually do this is some sort of autotuning based thing, where for any given concrete size you've benchmarked which one runs fastest, and you load that particular one.
    @parametrize("device", ("cuda", "cpu"))
    @parametrize("dtype", (torch.float32, torch.bfloat16))
    def test_cache_load_function(self, device, dtype):
    @parametrize("dynamic", (False, True))
Given that dynamic=None has a different behavior compared to True and False, it might be worth adding here as an option too
So it turns out there's a shortcoming with that change. None seems to be equivalent to False in terms of the FX graph that gets generated (at least for these tests). So I'd need to rearchitect the tests slightly to make sure each test gets a clean tmp directory. I can do that, but I sort of liked the current behavior of leaving the tmp directory intact for the run of the full set of tests, because we'd catch cases of a cache hit when we shouldn't (see my comment on setUpClass).
    @parametrize("device", ("cuda", "cpu"))
    @parametrize("dtype", (torch.float32, torch.bfloat16))
    def test_cache_load_model(self, device, dtype):
    @parametrize("dynamic", (False, True))
Same as previous comment
test/inductor/test_codecache.py
Outdated
    # Mark all tensor arg dimensions as dynamic to cause all shapes
    # to be symbolic
    def rand_dynamic(*args):
This is equivalent to dynamic=True on the compile function.
@pytorchbot merge
Merge started: your change will be merged once all checks pass (ETA 0-4 hours).
Summary: Add support for caching graphs that have tensor args with symbolic shapes. The high-level approach is to serialize guards with the on-disk cached object and validate that those guards pass before serving a cached object. Test Plan: New unit tests. Pull Request resolved: pytorch#111421. Approved by: https://github.com/ezyang
Stack from ghstack (oldest at bottom):
Summary: Add support for caching graphs that have tensor args with symbolic shapes. The high-level approach is to serialize guards with the on-disk cached object and validate that those guards pass before serving a cached object.
Test Plan: New unit tests
cc @voznesenskym @penguinwu @EikanWang @jgong5 @Guobing-Chen @XiaobingSuper @zhuhaozhe @blzheng @wenzhe-nrv @jiayisunx @peterbell10 @ipiszy @yf225 @chenyang78 @kadeng @muchulee8 @aakhundov @ColinPeppler