AOTI Minifier by yushangdi · Pull Request #139351 · pytorch/pytorch

Conversation

@yushangdi (Contributor) commented Oct 31, 2024

@pytorch-bot bot commented Oct 31, 2024

🔗 Helpful Links

🧪 See artifacts and rendered test results at hud.pytorch.org/pr/139351

Note: Links to docs will display an error until the docs builds have been completed.

✅ No Failures

As of commit 2787ffb with merge base b8cf324:
💚 Looks good so far! There are no failures yet. 💚

This comment was automatically generated by Dr. CI and updates every 15 minutes.

@yushangdi yushangdi marked this pull request as ready for review October 31, 2024 22:24
@yushangdi yushangdi requested a review from desertfire October 31, 2024 22:24
@yushangdi yushangdi changed the title Aoti package minifier AOTI Minifier Oct 31, 2024
# GPU Hardware Info:
# NVIDIA PG509-210 : 8
exported_program = torch.export.load('/data/users/shangdiy/pytorch/torch_compile_debug/run_2024_10_31_16_48_02_720863-pid_3598491/minifier/checkpoints/exported_program.pt2')
@desertfire (Contributor) commented:

I expect the repro file to contain trimmed model code that runs the full export-compile-run flow and then reproduces the problem. Loading another .pt2 file is not intuitive here.
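
For reference, a minimal sketch of the self-contained flow described here, assuming a toy module and a recent torch where aoti_compile_and_package accepts just the ExportedProgram (all names illustrative):

import torch
import torch._inductor

class Repro(torch.nn.Module):
    def forward(self, x):
        return x.relu() + 1

args = (torch.randn(4, 4),)

# export -> AOTI compile -> load -> run, all in one script
ep = torch.export.export(Repro(), args)
package_path = torch._inductor.aoti_compile_and_package(ep)
compiled = torch._inductor.aoti_load_package(package_path)
result = compiled(*args)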

@yushangdi (Contributor, Author) replied:

> I expect the repro file to contain trimmed model code that runs the full export-compile-run flow and then reproduces the problem. Loading another .pt2 file is not intuitive here.

@desertfire The .pt2 file is a trimmed model. I'm not sure how we can just include the model code instead of using an exported_program (the existing torch.compile minifier does this, but it's buggy). How can we convert the model code into a string and preserve the model safely?

If the trimmed model is as simple as a few nodes, that's possible. But if the trimmed model is more complicated, e.g. contains submodules or parameters, then converting it into a string and loading the inputs/state_dict is no longer trivial.

Another point is that the result of the minifier is a GraphModule object. It seems a bit roundabout to first convert the GraphModule into some nn.Module code and then export it back to a GraphModule.

@desertfire (Contributor) replied:

When we use TORCH_COMPILE_DEBUG=1, it will generate fx_graph_runnable.py. Maybe we can do something similar here?

with self.fopen("fx_graph_runnable.py") as fd:
    save_dir = None
    if torch._inductor.config.trace.save_real_tensors:
        inputs = torch._subclasses.fake_utils.try_convert_fake_to_real(inputs)
        save_dir = os.path.dirname(fd.name)

    # don't try to use stable hash torchinductor compilation if saving real
    # tensors, and avoid recursively trying to save real tensors inside of
    # the inductor compilation regardless
    stable_hash = torch._inductor.config.trace.save_real_tensors
    with torch._inductor.config.patch(
        {"trace.enabled": False, "trace.save_real_tensors": False}
    ):
        save_graph_repro(
            fd,
            gm,
            inputs,
            "inductor",
            save_dir=save_dir,
            stable_hash=stable_hash,
        )

@yushangdi (Contributor, Author) replied Nov 5, 2024:

> When we use TORCH_COMPILE_DEBUG=1, it will generate fx_graph_runnable.py. Maybe we can do something similar here?


yeah, we can do this, but I think this is not as robust as storing an exported program. This uses torch._dynamo.repro.after_aot.save_graph_repro under the hood, which only works for flattened graphs. It doesn't work out-of-the-box if there are submodules in the graph.

But maybe this is fine since AOTI runs on ep.module(), which is already flattened, so the repro is always a flattened graph?
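
For context, a small sketch of the flattening point made above (toy module, not from the PR):

import torch

class Outer(torch.nn.Module):
    def __init__(self):
        super().__init__()
        self.inner = torch.nn.Linear(4, 4)  # nested submodule

    def forward(self, x):
        return self.inner(x).relu()

ep = torch.export.export(Outer(), (torch.randn(2, 4),))
# ep.module() returns a single flat GraphModule: the call into self.inner
# is inlined into aten ops, so the minifier never sees nested submodules.
gm = ep.module()
print(gm.graph)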

@desertfire (Contributor) replied:

> so the repro is always a flattened graph?

I think so.

My main concern with generating .pt2 is OSS users' ability to see the generated code. Some OSS users don't want to share their whole model code, but if they see the generated code is small enough, they are more willing to share it as a repro. With .pt2 they can, in theory, unzip and examine the minimized code, but that extra step discourages people from sharing.

For that reason, I think maybe we can just print the minimized exported graph as a string comment in this file, so users can immediately see what is contained in the graph.
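
A minimal sketch of that idea, assuming ep is the minified ExportedProgram (the helper name is illustrative, not the PR's actual code):

import torch

def graph_as_comment(ep: torch.export.ExportedProgram) -> str:
    # str(ep) renders the exported program, including its FX graph and
    # signature; prefixing every line with '#' turns it into a comment
    # block that can sit at the top of the generated repro file.
    return "\n".join(f"# {line}" for line in str(ep).splitlines())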

@yushangdi (Contributor, Author) replied:

> > so the repro is always a flattened graph?
>
> I think so.
>
> My main concern with generating .pt2 is OSS users' ability to see the generated code. Some OSS users don't want to share their whole model code, but if they see the generated code is small enough, they are more willing to share it as a repro. With .pt2 they can, in theory, unzip and examine the minimized code, but that extra step discourages people from sharing.
>
> For that reason, I think maybe we can just print the minimized exported graph as a string comment in this file, so users can immediately see what is contained in the graph.

I modified the PR to return the exported graph as a string now. The doc is also updated.


if load_and_run:
    compiled_model = aoti_load_package(package_path)
    aoti_result = compiled_model(*args)
@desertfire (Contributor) commented:

We could compare with the eager result here and thus enable an accuracy minifier. OK to do it in a follow-up PR.

@yushangdi (Contributor, Author) replied:

> We could compare with the eager result here and thus enable an accuracy minifier. OK to do it in a follow-up PR.

Yeah, I can do it in a follow-up.
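
A hedged sketch of the comparison proposed above, assuming model is the original eager module, args its sample inputs, and package_path the compiled package (not the PR's actual implementation):

import torch
from torch._inductor import aoti_load_package

compiled_model = aoti_load_package(package_path)
aoti_result = compiled_model(*args)

# Run the same inputs eagerly and flag divergence beyond default tolerances;
# a failure here is what an accuracy minifier would bisect on.
eager_result = model(*args)
torch.testing.assert_close(aoti_result, eager_result)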

verbose_progress = False

# dump an aoti minifier if program errors
dump_aoti_minifier: bool = os.environ.get("DUMP_AOTI_MINIFIER", "0") == "1"
@desertfire (Contributor) commented:

I would recommend moving this under aot_inductor.
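
For reference, a hedged usage sketch of the flag as defined in the snippet above; since the config line evaluates the environment variable at import time, it must be set before torch is imported:

import os

# Must be set before importing torch: the config above evaluates
# os.environ.get("DUMP_AOTI_MINIFIER", "0") when the module is imported.
os.environ["DUMP_AOTI_MINIFIER"] = "1"

import torch
# ... run the failing export / AOTI-compile flow; on error, the minifier
# dump is triggered.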

@yushangdi (Contributor, Author) commented:

@pytorchbot merge

@pytorch-bot pytorch-bot bot added the ciflow/trunk Trigger trunk jobs on your pull request label Nov 7, 2024
@pytorchmergebot (Collaborator) commented:

Merge started

Your change will be merged once all checks pass (ETA 0-4 Hours).

Learn more about merging in the wiki.

Questions? Feedback? Please reach out to the PyTorch DevX Team

Advanced Debugging: check the merge workflow status here.

@henrylhtsang (Contributor) commented:

Thanks for the work. Not sure how much work that would be, but is it possible to print some metadata for the input arguments, so users can try to reproduce them without needing to load local files?

e.g. for tensors, print the shape and dtype; for Python constants, print everything.

@yushangdi (Contributor, Author) replied:

> Thanks for the work. Not sure how much work that would be, but is it possible to print some metadata for the input arguments, so users can try to reproduce them without needing to load local files?
>
> e.g. for tensors, print the shape and dtype; for Python constants, print everything.

Thanks for the suggestion! Yeah I can add that, it shouldn't be too hard.
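
A minimal sketch of the metadata dump suggested here, assuming args are the repro inputs (the helper name is illustrative, not what the PR ended up with):

import torch

def describe_input(arg: object) -> str:
    if isinstance(arg, torch.Tensor):
        # For tensors, shape/dtype/device is enough to regenerate a
        # stand-in input (e.g. with torch.randn) without local files.
        return f"Tensor(shape={tuple(arg.shape)}, dtype={arg.dtype}, device={arg.device})"
    # Python constants are small enough to print verbatim.
    return repr(arg)

# Emitted as comments at the top of the repro file:
for i, arg in enumerate(args):
    print(f"# arg {i}: {describe_input(arg)}")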

pobin6 pushed a commit to pobin6/pytorch that referenced this pull request Dec 5, 2024
See documentation at https://docs-preview.pytorch.org/pytorch/pytorch/139351/torch.compiler_aot_inductor_minifier.html.

Add a minifier for AOTI.

Test Plan:
python test/inductor/test_minifier.py

Pull Request resolved: pytorch#139351
Approved by: https://github.com/desertfire
@github-actions github-actions bot deleted the aoti_package_minifier branch December 9, 2024 02:14