adds the pipeline for pixart alpha controlnet by raulc0399 · Pull Request #8857 · huggingface/diffusers · GitHub

@raulc0399
Contributor

raulc0399 commented Jul 12, 2024

this PR adds the controlnet pipeline for the pixart alpha diffusion model

the following example uses the HED edge to control the generation.

import torch
import torchvision.transforms as T
import torchvision.transforms.functional as TF

from diffusers.models import PixArtControlNetAdapterModel
from diffusers.pipelines import PixArtAlphaControlnetPipeline, get_closest_hw
import PIL.Image as Image

from controlnet_aux import HEDdetector

input_image_path = "asset/images/controlnet/car.jpg"
given_image = Image.open(input_image_path)

path_to_controlnet = "raulc0399/pixart-alpha-hed-controlnet"
prompt = "modern car, city in background, clear sky, suny day"

weight_dtype = torch.float16
image_size = 1024

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

controlnet = PixArtControlNetAdapterModel.from_pretrained(
    path_to_controlnet,
    torch_dtype=weight_dtype,
    use_safetensors=True,
).to(device)

pipe = PixArtAlphaControlnetPipeline.from_pretrained(
    "PixArt-alpha/PixArt-XL-2-1024-MS",
    controlnet=controlnet,
    torch_dtype=weight_dtype,
    use_safetensors=True,
).to(device)

# preprocess image, generate HED edge
hed = HEDdetector.from_pretrained("lllyasviel/Annotators")

width, height = get_closest_hw(given_image.size[0], given_image.size[1], image_size)

condition_transform = T.Compose([
    T.Lambda(lambda img: img.convert('RGB')),
    T.Resize(int(min(height, width))),
    T.CenterCrop([int(height), int(width)]),
    # no T.ToTensor() here: the controlnet_aux HED detector expects a PIL image or NumPy array
])

control_image = condition_transform(given_image)
hed_edge = hed(control_image, detect_resolution=image_size, image_resolution=image_size)

with torch.no_grad():
    out = pipe(
        prompt=prompt,
        image=hed_edge,
        num_inference_steps=14,
        guidance_scale=4.5,
        height=image_size,
        width=image_size,
    )

    out.images[0].save("./output.jpg")

here are some images: original image, control image and generated image
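
For readers unfamiliar with `get_closest_hw`: the example uses it to snap the input image to a resolution the model supports. The sketch below is not the function from this PR. It assumes a small, hypothetical list of (height, width) buckets and a hypothetical helper name `closest_hw_sketch`, and its return order (height, width) is a convention chosen for this sketch only. It illustrates one plausible approach: picking the bucket whose aspect ratio is nearest to the input's.

```python
from typing import List, Tuple

# Hypothetical (height, width) buckets for a 1024-pixel model.
# Illustrative values only; they are not taken from the PR.
HYPOTHETICAL_BUCKETS: List[Tuple[int, int]] = [
    (704, 1408),
    (864, 1152),
    (1024, 1024),
    (1152, 864),
    (1408, 704),
]


def closest_hw_sketch(
    width: int,
    height: int,
    buckets: List[Tuple[int, int]] = HYPOTHETICAL_BUCKETS,
) -> Tuple[int, int]:
    """Return the (height, width) bucket whose aspect ratio is nearest to height / width."""
    if width <= 0 or height <= 0:
        raise ValueError("width and height must be positive")
    target = height / width
    # Compare each bucket's height/width ratio against the input's ratio.
    return min(buckets, key=lambda hw: abs(hw[0] / hw[1] - target))


landscape = closest_hw_sketch(1920, 1080)
square = closest_hw_sketch(1000, 1000)
portrait = closest_hw_sketch(600, 800)
print(landscape, square, portrait)
```

Snapping to a fixed bucket keeps the control image and the generated image at a size the model was trained on, at the cost of a crop or resize of the original.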

Who can review?

@yiyixuxu @lawrence-cj

@yiyixuxu
Collaborator

is this the checkpoint? https://huggingface.co/PixArt-alpha/PixArt-ControlNet
I don't see any downloads, not sure if it's tracking correctly

is this pixart alpha controlnet used a lot in the community? if not, maybe we can make a community pipeline to start with?

also cc @asomoza

@raulc0399
Contributor Author

@yiyixuxu that is the pixart controlnet model for HED conditioning as uploaded by the authors of pixart.
for this pipeline i have converted the controlnet layers to safetensors, uploaded here https://huggingface.co/raulc0399/pixart-alpha-hed-controlnet

they can be used with this pipeline

@asomoza
Member

asomoza commented Jul 17, 2024

why does it have its own implementation of the HED detector? It doesn't work with the regular one that everyone uses? Have you tested it with the one from the controlnet_aux library?

@raulc0399
Contributor Author

@asomoza

> why does it have its own implementation of the HED detector? It doesn't work with the regular one that everyone uses? Have you tested it with the one from the controlnet_aux library?

the sample above just used the HED class that the authors had in their repository, and that was used to train their HED controlnet.

but i just checked it, and it seems to be the same as, or better said adapted from, the one in controlnet_aux

@asomoza
Member

asomoza commented Jul 17, 2024

thanks, I'll give it a test later. I was asking because if it was trained with a custom HED detector which produces different results than the default one it will be really hard for people to use it.

It would be nice if you could post some results (images) in the PR description.

@raulc0399
Contributor Author

raulc0399 commented Jul 17, 2024

> thanks, I'll give it a test later. I was asking because if it was trained with a custom HED detector which produces different results than the default one it will be really hard for people to use it.

using the HED from controlnet_aux, it loses some quality.
i will run some more tests with that one.

but i also have a training script that i am testing before creating a PR:
https://github.com/raulc0399/PixArt-alpha/blob/master_train_controlnet_diffusers/controlnet/train_pixart_controlnet_hf.py

that can be used to train further models.

> It would be nice if you could post some results (images) in the PR description.

will do.

@raulc0399
Contributor Author

i have to correct my previous comment. i was using the default params for HED, which resized the image to 512; if i use 1024 instead, it works as it should.
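
The effect described in this correction can be reasoned about without running the detector. The sketch below assumes that the preprocessor scales the shorter side of the input to the requested resolution and rounds each side to a multiple of 64. That resizing rule is an assumption for illustration, not something confirmed in this thread, and `resized_hw` is a hypothetical helper.

```python
from typing import Tuple


def resized_hw(height: int, width: int, resolution: int, multiple: int = 64) -> Tuple[int, int]:
    """Scale so the shorter side equals `resolution`, then round both sides to `multiple`.

    Assumed resizing rule for illustration; not taken from any library source.
    """
    if min(height, width) <= 0:
        raise ValueError("height and width must be positive")
    scale = resolution / min(height, width)
    new_h = int(round(height * scale / multiple)) * multiple
    new_w = int(round(width * scale / multiple)) * multiple
    return new_h, new_w


# A 1024x1024 control image processed at 512 versus at 1024.
default_hw = resized_hw(1024, 1024, resolution=512)
full_hw = resized_hw(1024, 1024, resolution=1024)

# Fraction of the original pixel count that survives at the lower resolution.
pixel_fraction = (default_hw[0] * default_hw[1]) / (full_hw[0] * full_hw[1])
print(default_hw, full_hw, pixel_fraction)
```

Under this assumption, a 512 setting processes the edge map at a quarter of the pixel count of a 1024 image, which is consistent with the quality loss reported earlier and with passing `detect_resolution=image_size` and `image_resolution=image_size` in the PR description.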

@asomoza
Member

asomoza commented Jul 18, 2024

Thanks, the results look nice. Since we only have one controlnet, maybe do what @yiyixuxu suggested: let's start with a community pipeline first and then, as it gets traction and we have more controlnets, move it to core.

@raulc0399
Contributor Author

@asomoza

> Thanks, the results look nice. Since we only have one controlnet, maybe do what @yiyixuxu suggested: let's start with a community pipeline first and then, as it gets traction and we have more controlnets, move it to core.

ok, i will move it to the examples folder and put the training script there as well.
i have done some initial tests on the "fusing/fill50k" dataset to validate that it works

@raulc0399
Contributor Author

@yiyixuxu @asomoza
i have moved everything to the examples folder
i have also added the training script, together with sh files for starting the training and for running the pipeline

@@ -0,0 +1,292 @@
from typing import Any, Dict, Optional
Collaborator

I think the pipelines can go to the examples/community folder, and the training script can stay in the examples/pixart folder

cc @sayakpaul

Member

Okay with that plan.

Member

This needs to be addressed first.

@raulc0399
Contributor Author

@yiyixuxu
the last commit moves the pipeline and the example on how to run it to examples/community

Collaborator

@yiyixuxu left a comment

I left some comments, thanks!

from diffusers.models.modeling_utils import ModelMixin
from diffusers.models.modeling_outputs import Transformer2DModelOutput

class PixArtControlNetAdapterBlock(nn.Module):
Collaborator

I think we need to copy paste all the model code here into the pipeline so that the pipeline will be able to run, no?

Contributor Author

raulc0399 commented Jul 24, 2024

the pipeline code changes the sys path, so it runs
sys.path.append(os.path.dirname(os.path.dirname(os.path.abspath(__file__))))
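
A minimal, self-contained sketch of why that line lets the pipeline find the model code: appending a directory to `sys.path` makes modules inside it importable by name. The module name `fake_controlnet_model` and its contents are hypothetical, and the file is created in a temporary folder purely for the demonstration.

```python
import importlib
import os
import sys
import tempfile

# Build a throwaway "project" folder containing one module (hypothetical name).
project_dir = tempfile.mkdtemp()
with open(os.path.join(project_dir, "fake_controlnet_model.py"), "w") as f:
    f.write("MODEL_NAME = 'adapter-sketch'\n")

# Same idea as the line quoted above: extend the import search path at runtime.
sys.path.append(project_dir)
importlib.invalidate_caches()

# The module can now be imported by its bare name.
module = importlib.import_module("fake_controlnet_model")
print(module.MODEL_NAME)
```

The trade-off is that the import only works while the files keep their relative locations on disk, so a pipeline file copied elsewhere on its own would no longer find the model module.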

@@ -0,0 +1,81 @@
import sys
Collaborator

Contributor Author

i have added the section

Collaborator

this file should not be here, no? should be in the same folder as others?

Contributor Author

just moved it

@github-actions
Contributor

This issue has been automatically marked as stale because it has not had recent activity. If you think this still needs to be addressed please comment on this thread.

Please note that issues that do not follow the contributing guidelines are likely to be ignored.

@github-actions github-actions bot added the stale Issues that haven't received updates label Sep 14, 2024
@yiyixuxu yiyixuxu removed the stale Issues that haven't received updates label Sep 17, 2024
@yiyixuxu
Collaborator

@sayakpaul can you take a look to see if we can merge this now?

@yiyixuxu
Collaborator

@raulc0399 can you run make style?

@HuggingFaceDocBuilderDev

The docs for this PR live here. All of your documentation changes will be reflected on that endpoint. The docs are available until 30 days after the last update.

@github-actions github-actions bot added the stale Issues that haven't received updates label Oct 13, 2024
@yiyixuxu yiyixuxu removed the stale Issues that haven't received updates label Oct 15, 2024
@yiyixuxu
Collaborator

gentle ping @raulc0399 - are we interested in moving this to the research folder?

@raulc0399
Contributor Author

@yiyixuxu yes will do, sorry for the delay.

@raulc0399
Contributor Author

@yiyixuxu i have moved the pipeline and training script under research_projects/pixart

@yiyixuxu yiyixuxu requested a review from sayakpaul October 17, 2024 22:36
@yiyixuxu
Collaborator

@sayakpaul does this look good to you now?

Member

@sayakpaul left a comment

Thanks so much!

Could you run make style && make quality?

@lawrence-cj
Contributor

May I request permission to push commits to raulc0399:main_pixart_alpha_controlnet, so that I can help run make style and make quality? @raulc0399

@raulc0399
Contributor Author

> May I request permission to push commits to raulc0399:main_pixart_alpha_controlnet, so that I can help run make style and make quality? @raulc0399

@lawrence-cj sure

@lawrence-cj
Contributor

ERROR: Permission to raulc0399/diffusers.git denied to lawrence-cj. Could not read from remote repository. Please make sure you have the correct access rights and the repository exists.

It seems I still cannot push commits to your branch.

@raulc0399
Contributor Author

@lawrence-cj i invited you just now as collaborator

@lawrence-cj
Contributor

lawrence-cj commented Oct 28, 2024

Cool, @raulc0399. I have already run make style && make quality.
Gentle ping, yiyi @yiyixuxu.

@yiyixuxu yiyixuxu merged commit c5376c5 into huggingface:main Oct 28, 2024
8 checks passed
@yiyixuxu
Collaborator

thank you @lawrence-cj @raulc0399

@lawrence-cj
Contributor

Thank you so much. Respect. @raulc0399 @sayakpaul @yiyixuxu

sayakpaul added a commit that referenced this pull request Dec 23, 2024
* add the controlnet pipeline for pixart alpha

---------

Co-authored-by: YiYi Xu <yixu310@gmail.com>
Co-authored-by: Sayak Paul <spsayakpaul@gmail.com>
Co-authored-by: junsongc <cjs1020440147@icloud.com>
@chaewon-huh

Thanks for adding this. I really needed it and will make good use of it!
