Rewrite AuraFlowPatchEmbed.pe_selection_index_based_on_dim to be torch.compile compatible #11297
The docs for this PR live here. All of your documentation changes will be reflected on that endpoint. The docs are available until 30 days after the last update.
Feel free to update docs including the 0.3 note :)
I think if we get the same numerical outputs with and without the changes, that should be more than enough. @StrongerXi, wanna investigate this?
Thanks a lot for your efforts here and also for testing torch.compile() with this.
Just as an FYI, we're working on #11085 to have better testing for torch.compile.
For my understanding, does this PR solve the recompilation issues for higher resolutions?
should I update both So maybe we should have some detection or at least update docs? I know some people prefer 0.2 right now.
Are the current tests sufficient to confirm, or should I add something extra first? Any comparisons I should run on my end?
Correct, with this change I have not seen any more recompilations so far with AF loaded via GGUF.
Thanks for taking so much effort to enable torch.compile here. This workstream is truly amazing! Cc @bobrenjc93 @laithsakka for dynamic shape guards related rewrite review. Might be a good rewrite to document in the dynamic shape manual.
AFK currently. Please allow me some time to get back to you.
The tests in #11297 (comment) are sufficient. Thanks!
Well, when Does this answer your question? |
What I would also do is the following (perhaps in a separate PR): Add a new test class / method in https://github.com/huggingface/diffusers/blob/main/tests/pipelines/aura_flow/test_pipeline_aura_flow.py that checks no recompilation is triggered when we go for higher resolutions. I believe we won't need a pre-trained checkpoint for this. We could use the dummy model from
I can work on this and, when ready, ask for a review from you and @anijain2305. WDYT? LMK also if this test case makes sense. Also, @AstraliteHeart, if possible, it would be great to update the docs of AuraFlow with a section on no recompilations when using
@yiyixuxu could you also review this PR? This helps make AuraFlow more compatible with
thanks!
Looks great, thanks!
cc @laithsakka you probably want to include a section on reducing recompiles by rewriting tensor indexing operations like this PR in your recompilations guide. Also you should probably either write a separate OSS version or publish the internal version of https://docs.google.com/document/d/1QgQLVBNKSYMeNbG5sEz_pwffL9PlKRHKXMI4ft3H9gA/edit?tab=t.0#heading=h.a37bpg8ay2f4
@AstraliteHeart just waiting for you to provide some confirmations to my comments above when you have time. We will then merge :)
Updated the docs to reflect the correct default values. I don't think we need the 0.3 note: I had assumed the values were not read from the model, which was incorrect (see below).
Rechecked the values populated from the config, and you are correct.
I would never say "no" to someone volunteering to write tests, but lmk if you want me to work on that.
For the compilation example, I believe the only special thing right now is torch.fx.experimental._config.use_duck_shape = False
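import torch
import torch.fx.experimental._config
from diffusers import AuraFlowPipeline, AuraFlowTransformer2DModel, GGUFQuantizationConfig

# Disable duck shaping (the one special setting mentioned above).
torch.fx.experimental._config.use_duck_shape = False

# Load the GGUF-quantized transformer from a single file.
transformer = AuraFlowTransformer2DModel.from_single_file(
    "https://huggingface.co/city96/AuraFlow-v0.3-gguf/blob/main/aura_flow_0.3-Q2_K.gguf",
    quantization_config=GGUFQuantizationConfig(compute_dtype=torch.bfloat16),
    torch_dtype=torch.bfloat16,
)
pipeline = AuraFlowPipeline.from_pretrained(
    "fal/AuraFlow-v0.3",
    torch_dtype=torch.bfloat16,
    transformer=transformer,
).to("cuda")
# Compile the transformer as a single graph with dynamic shapes enabled.
pipeline.transformer = torch.compile(pipeline.transformer, fullgraph=True, dynamic=True)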
Happy to get this merged (but please check the doc update just in case). @yiyixuxu @bobrenjc93 thank you for having a look.
Thank you! Can I push directly to your branch to include the snippet in #11297 (comment) in the AuraFlow pipeline docs?
What does this PR do?
Updates AuraFlowPatchEmbed.pe_selection_index_based_on_dim so that the AuraFlowTransformer2DModel can be fully torch.compile(d).
The old and new code generate the same images, but I am not expert enough to know whether the change hurts performance or has hidden caveats.
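To illustrate the kind of rewrite involved, here is a minimal sketch, not the actual diff. It assumes the method selects a centered `h_p x w_p` window of flattened indices from an `h_max x w_max` positional-embedding grid; the function names and the exact form of both variants are illustrative. One variant slices a full index grid, the other computes the same indices arithmetically from row and column ranges.

```python
import torch


def select_by_slicing(h_p: int, w_p: int, h_max: int, w_max: int) -> torch.Tensor:
    # Build the full index grid, then slice out the centered window.
    grid = torch.arange(h_max * w_max).view(h_max, w_max)
    start_h = h_max // 2 - h_p // 2
    start_w = w_max // 2 - w_p // 2
    return grid[start_h : start_h + h_p, start_w : start_w + w_p].flatten()


def select_by_arithmetic(h_p: int, w_p: int, h_max: int, w_max: int) -> torch.Tensor:
    # Compute the same flattened indices directly from row/column ranges,
    # without slicing a precomputed grid.
    start_h = h_max // 2 - h_p // 2
    start_w = w_max // 2 - w_p // 2
    rows = torch.arange(start_h, start_h + h_p)
    cols = torch.arange(start_w, start_w + w_p)
    row_idx, col_idx = torch.meshgrid(rows, cols, indexing="ij")
    return (row_idx * w_max + col_idx).flatten()


if __name__ == "__main__":
    # Both variants should agree for every window size that fits in the grid.
    for h_p, w_p in [(2, 2), (3, 5), (8, 8)]:
        a = select_by_slicing(h_p, w_p, 8, 8)
        b = select_by_arithmetic(h_p, w_p, 8, 8)
        assert torch.equal(a, b)
    print(select_by_arithmetic(2, 2, 4, 4).tolist())
```

Comparing the two index tensors like this is a cheap complement to comparing generated images, since identical indices imply an identical positional-embedding selection.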
I've noticed some weirdness while fixing this issue: in the docs, AuraFlowTransformer2DModel has
pos_embed_max_size (int, defaults to 4096): Maximum positions to embed from the image latents.
and in the code
pos_embed_max_size: int = 1024,
but AFAIK for AuraFlow 0.3 it actually should be something like?
Fixes # Originally filed in torch - (issue)
Who can review?
@cloneofsimo @sayakpaul @yiyixuxu