[LoRA] feat: lora support for SANA. #10234
Conversation
return noise, input_ids, pipeline_inputs

@unittest.skip("Not supported in Sana.")
Skipped tests are the same as Mochi.
| "prompt": "", | ||
| "negative_prompt": "", |
Check this internal thread:
https://huggingface.slack.com/archives/C065E480NN9/p1734324025408149
The docs for this PR live here. All of your documentation changes will be reflected on that endpoint. The docs are available until 30 days after the last update.
So I was not requested for review, but I saw the latest commit in my email notifications. I had shared this concern in a previous LoRA refactor PR and in this comment. This is because I often find myself having to refer to the documentation for different pipelines instead of just being able to use one consistent parameter name to pass the LoRA scale with, and it is frustrating because you wait for the pipeline to load only to find it fail immediately. I'm not sure if others resonate with this, but anyone using LoRAs often will have faced this. We have ...
Yeah, I don't mind.
Thanks for the super fast work! Looks good to merge after some of the more important reviews are addressed.
vae.to(dtype=torch.float32)
transformer.to(accelerator.device, dtype=weight_dtype)
# because Gemma2 is particularly suited for bfloat16.
text_encoder.to(dtype=torch.bfloat16)
I think we could instead load with torch_dtype=torch.bfloat16 and keep the same comment. This is because weight casting this way ignores _keep_modules_in_fp32. I could not get our numerical precision unit tests to match when using the two different ways while working on the integration PR.
Oh this is coming from the example provided in https://huggingface.co/docs/diffusers/main/en/api/pipelines/sana#diffusers.SanaPipeline.__call__.example. In this case, we're doing the exact same thing and we are not fine-tuning the text encoder.
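For reference, a minimal sketch of the two loading approaches being compared, assuming the text encoder is loaded from the `text_encoder` subfolder with a generic `transformers` auto class; the class and arguments the training script actually uses may differ.

```python
import torch
from transformers import AutoModel

# Hypothetical loading code for illustration; the script's exact class and
# arguments may differ.
model_id = "Efficient-Large-Model/Sana_1600M_1024px_diffusers"

# Suggested approach: pass the dtype at load time so any modules the loader is
# configured to keep in fp32 are respected while the weights are loaded.
text_encoder = AutoModel.from_pretrained(
    model_id, subfolder="text_encoder", torch_dtype=torch.bfloat16
)

# Current approach: load first, then blanket-cast every parameter with `.to()`,
# which bypasses any keep-in-fp32 handling.
text_encoder = AutoModel.from_pretrained(model_id, subfolder="text_encoder")
text_encoder = text_encoder.to(dtype=torch.bfloat16)
```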
)
# VAE should always be kept in fp32 for SANA (?)
vae.to(dtype=torch.float32)
FP32 should be good, but I'm not 100% sure. I think the AutoencoderDC models were all trained in bf16. Maybe @lawrence-cj can comment.
This is just to be sure the VAE's precision isn't a bottleneck for getting good-quality training runs. It's a small VAE anyway, so it won't matter too much, I guess.
Yes. AutoencoderDC is trained in BF16, and FP32 testing is also fine; it will just cost a lot of additional GPU memory in FP32.
We're offloading it to CPU when it's not in use, and when cache_latents is supplied through the CLI, we precompute the latents and delete the VAE. So I guess it's okay for now?
I mean that it's the VAE.decode() call, not the VAE model itself, that will consume much more GPU memory if it runs in FP32, especially when the batch_size is more than 1. Not sure if I understand right.
Oh I think we should be good as we barely make use of decode() in training.
OK, that's cool. Then the only remaining concern is when we visualize the training results during training, since that does go through decode().
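To make the cache_latents point above concrete, here is a rough sketch of the idea; `vae`, `accelerator`, `train_dataloader`, and `args.cache_latents` are stand-ins for the script's actual objects and flags, and the exact output attribute of the encoder may differ.

```python
import gc
import torch

# Rough sketch of the `cache_latents` path described above.
latents_cache = []
if args.cache_latents:
    vae.to(accelerator.device, dtype=torch.float32)
    with torch.no_grad():
        for batch in train_dataloader:
            pixel_values = batch["pixel_values"].to(accelerator.device, dtype=torch.float32)
            # Only the encoder is needed during training; decode() never runs here.
            latents_cache.append(vae.encode(pixel_values).latent.cpu())
    # The VAE is no longer needed, so free its memory entirely.
    del vae
    gc.collect()
    torch.cuda.empty_cache()
```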
clean_caption: bool = False,
max_sequence_length: int = 300,
complex_human_instruction: Optional[List[str]] = None,
lora_scale: Optional[float] = None,
Are we training the text encoder? If not, we can maybe remove these changes.
This was so there are no surprises for our users when text encoder training support is merged. It's common for the encode_prompt() method to be equipped to handle lora_scale.
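For context, a sketch of the common pattern referred to here, where encode_prompt() temporarily applies the LoRA scale to the text encoder using the diffusers PEFT helpers; the function body below is illustrative and not the pipeline's exact code.

```python
import torch
from diffusers.utils import USE_PEFT_BACKEND, scale_lora_layers, unscale_lora_layers

def encode_prompt(text_encoder, tokenizer, prompt, device, lora_scale=None, max_sequence_length=300):
    # Temporarily apply the requested LoRA scale to the text encoder's LoRA layers.
    if lora_scale is not None and USE_PEFT_BACKEND:
        scale_lora_layers(text_encoder, lora_scale)

    inputs = tokenizer(
        prompt,
        padding="max_length",
        max_length=max_sequence_length,
        truncation=True,
        return_tensors="pt",
    ).to(device)
    with torch.no_grad():
        prompt_embeds = text_encoder(inputs.input_ids, attention_mask=inputs.attention_mask)[0]

    # Undo the scaling so subsequent calls see the original weights.
    if lora_scale is not None and USE_PEFT_BACKEND:
        unscale_lora_layers(text_encoder, lora_scale)
    return prompt_embeds
```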
Co-authored-by: Aryan <aryan@huggingface.co>
@a-r-r-o-w your comments have been addressed. @lawrence-cj could you review / test the training script if you want?
thanks!
Failing tests are unrelated and can safely be ignored. Will add a training test in a follow-up PR.
Working on it. Will fine-tune the model using your pokemon dataset.
* feat: lora support for SANA.
* make fix-copies
* rename test class.
* attention_kwargs -> cross_attention_kwargs.
* Revert "attention_kwargs -> cross_attention_kwargs." This reverts commit 23433bf.
* exhaust 119 max line limit
* sana lora fine-tuning script.
* readme
* add a note about the supported models.
* Apply suggestions from code review (Co-authored-by: Aryan <aryan@huggingface.co>)
* style
* docs for attention_kwargs.
* remove lora_scale from pag pipeline.
* copy fix

Co-authored-by: Aryan <aryan@huggingface.co>
What does this PR do?
Example LoRA fine-tuning command:
Notes
mixed_precision="fp16" is leading to NaN loss values despite the recommendation to use FP16 for "Efficient-Large-Model/Sana_1600M_1024px_diffusers".
Results
https://wandb.ai/sayakpaul/dreambooth-sana-lora/runs/tf9fo8o6
Pre-trained LoRA: https://huggingface.co/sayakpaul/yarn_art_lora_sana
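As a quick way to try the pre-trained LoRA linked above, a minimal inference sketch (the prompt, dtype, and device choices are illustrative):

```python
import torch
from diffusers import SanaPipeline

pipeline = SanaPipeline.from_pretrained(
    "Efficient-Large-Model/Sana_1600M_1024px_diffusers", torch_dtype=torch.bfloat16
).to("cuda")

# Loading LoRA weights into SanaPipeline is what this PR enables.
pipeline.load_lora_weights("sayakpaul/yarn_art_lora_sana")

image = pipeline(prompt="a puppy, yarn art style").images[0]
image.save("yarn_art_puppy.png")
```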