[tests] refactor vae tests #9808
Conversation
```diff
         else:
             z = posterior.mode()
-        dec = self.decode(z)
+        dec = self.decode(z).sample
```
Otherwise we return a tuple of DecoderOutput when return_dict=False.
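The point can be sketched with plain-Python stand-ins (the class and function names below are illustrative stand-ins, not the actual diffusers implementation, where `DecoderOutput` wraps the decoded tensor in `.sample`):

```python
from dataclasses import dataclass


# Minimal stand-in for diffusers' DecoderOutput: a wrapper whose
# `.sample` field holds the decoded tensor (here, a plain list).
@dataclass
class DecoderOutput:
    sample: list


def decode(z):
    # stands in for self.decode(z), which returns a DecoderOutput
    return DecoderOutput(sample=[v * 2 for v in z])


def forward(z, return_dict=True):
    dec = decode(z).sample  # unwrap here ...
    if not return_dict:
        return (dec,)  # ... so the tuple holds the tensor, not a DecoderOutput
    return DecoderOutput(sample=dec)


out = forward([1, 2], return_dict=False)
assert not isinstance(out[0], DecoderOutput)  # tuple of tensors, as intended
```

Without the `.sample` unwrap, the `return_dict=False` branch would hand callers a tuple containing a `DecoderOutput` instead of the decoded tensor.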
```python
sample_size = (
    self.config.sample_size[0]
    if isinstance(self.config.sample_size, (list, tuple))
    else self.config.sample_size
)
self.tile_latent_min_size = int(sample_size / (2 ** (len(self.config.block_out_channels) - 1)))
self.tile_overlap_factor = 0.25
```
Unused.
```diff
-        output = [self._tiled_decode(x_slice) if self.use_tiling else self.decoder(x) for x_slice in x.split(1)]
+        output = [
+            self._tiled_decode(x_slice) if self.use_tiling else self.decoder(x_slice) for x_slice in x.split(1)
+        ]
```
Should use `x_slice`, not `x`.
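The bug fixed here is easy to reproduce with a toy decoder (an illustrative sketch, not the diffusers code):

```python
# Toy illustration of the sliced-decode bug: the comprehension iterates
# over slices of x, but the buggy branch decoded the whole batch `x`
# on every iteration instead of the current slice.
def decoder(batch):
    return [v + 1 for v in batch]


x = [10, 20, 30]
slices = [[v] for v in x]  # stands in for x.split(1)

buggy = [decoder(x) for x_slice in slices]        # ignores x_slice
fixed = [decoder(x_slice) for x_slice in slices]  # decodes one slice at a time

assert buggy == [[11, 21, 31]] * 3  # every entry is the full batch
assert fixed == [[11], [21], [31]]  # one output per slice
```

The buggy version silently decodes the full batch once per slice, so it wastes compute and returns outputs of the wrong shape.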
Could maybe refactor this further to match how the current Cog/Mochi implementations use a `_decode` method. The code flow is a bit easier to follow that way.
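The dispatch pattern referred to here roughly looks like the following (a hypothetical sketch of the structure, not the actual Cog/Mochi code; all names are illustrative):

```python
# Hypothetical sketch of the `_decode` dispatch pattern: the public
# `decode` iterates over the batch, while `_decode` handles one sample
# and is the single place where the tiled/plain choice is made.
class ToyVAE:
    def __init__(self, use_tiling=False):
        self.use_tiling = use_tiling

    def _plain_decode(self, z):
        return [v * 2 for v in z]

    def _tiled_decode(self, z):
        # decode in two halves and concatenate (stands in for spatial tiling)
        mid = len(z) // 2
        return self._plain_decode(z[:mid]) + self._plain_decode(z[mid:])

    def _decode(self, z):
        # single decision point: tiled vs. plain decoding
        if self.use_tiling:
            return self._tiled_decode(z)
        return self._plain_decode(z)

    def decode(self, z_batch):
        # public entry point: iterate over samples, delegate to _decode
        return [self._decode(z) for z in z_batch]
```

Keeping the tiling decision inside `_decode` means the per-slice loop in `decode` no longer needs an inline conditional, which is where the `x` vs. `x_slice` bug crept in.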
Yeah, sure, feel free to club those in your PR.
tests/models/autoencoders/test_models_autoencoder_kl_allegro.py
@DN6 a gentle ping.
```diff
             temb,
             zq,
-            conv_cache=conv_cache.get(conv_cache_key),
+            conv_cache.get(conv_cache_key),
```
Because `torch.utils.checkpoint.checkpoint()` doesn't accept a `conv_cache` keyword argument.
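A pure-Python stand-in shows why the keyword form fails (this mimics a checkpoint-style wrapper that only forwards positional arguments; it is an assumption for illustration, not torch's actual implementation, whose signature has also changed across versions):

```python
# Stand-in for a checkpoint-style wrapper that only forwards
# positional arguments to the wrapped function.
def checkpoint(fn, *args):
    return fn(*args)


def block(temb, zq, conv_cache):
    return (temb, zq, conv_cache)


conv_cache = {"key": "cached-conv-state"}

# Keyword form raises: `checkpoint` itself has no `conv_cache` parameter,
# so Python reports an unexpected keyword argument.
try:
    checkpoint(block, 1, 2, conv_cache=conv_cache.get("key"))
except TypeError:
    pass

# Positional form works: the value reaches `block` as its third argument.
assert checkpoint(block, 1, 2, conv_cache.get("key")) == (1, 2, "cached-conv-state")
```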
```python
if self.model_class.__name__ in [
    "UNetSpatioTemporalConditionModel",
    "AutoencoderKLTemporalDecoder",
]:
    return
```
Because these are supported.
@a-r-r-o-w @DN6 a gentle ping.
@a-r-r-o-w merging this to unblock you and will let you add any leftover tests. Hopefully that is okay.
* add: autoencoderkl tests
* autoencodertiny.
* fix
* asymmetric autoencoder.
* more
* integration tests for stable audio decoder.
* consistency decoder vae tests
* remove grad check from consistency decoder.
* cog
* bye test_models_vae.py
* fix
* fix
* remove allegro
* fixes
* fixes
* fixes

Co-authored-by: Dhruv Nair <dhruv.nair@gmail.com>
What does this PR do?
Internal thread: https://huggingface.slack.com/archives/C065E480NN9/p1730203711189419.
Tears apart `test_models_vae.py` to break the tests up in accordance with the Autoencoder model classes we have under `src/diffusers/models/autoencoders`. Didn't include Allegro as it's undergoing some refactoring love from Aryan. Discussed internally.
Some comments inline.