Fix flux controlnet mode to take into account batch size by christopher-beckham · Pull Request #9406 · huggingface/diffusers

Conversation

@christopher-beckham
Contributor

@christopher-beckham christopher-beckham commented Sep 10, 2024

What does this PR do?

For some reason the control_mode in the Flux ControlNet pipeline does not take the batch size into account, which results in a downstream shape-mismatch error in the call to the controlnet. (This happens when control_mode is passed as an int, which the docstring says is a valid configuration, i.e. Optional[Union[int, List[int]]] = None.)

Traceback (most recent call last):
  File "/home/chris/github/test-controlnet/test.py", line 98, in <module>
    image = pipe(
  File "/home/chris/.conda/envs/me/lib/python3.9/site-packages/torch/utils/_contextlib.py", line 116, in decorate_context
    return func(*args, **kwargs)
  File "/home/chris/github/diffusers/src/diffusers/pipelines/flux/pipeline_flux_controlnet.py", line 804, in __call__
    num_channels_latents = self.transformer.config.in_channels // 4
  File "/home/chris/.conda/envs/me/lib/python3.9/site-packages/torch/nn/modules/module.py", line 1553, in _wrapped_call_impl
    return self._call_impl(*args, **kwargs)
  File "/home/chris/.conda/envs/me/lib/python3.9/site-packages/torch/nn/modules/module.py", line 1562, in _call_impl
    return forward_call(*args, **kwargs)
  File "/home/chris/.conda/envs/me/lib/python3.9/site-packages/accelerate/hooks.py", line 169, in new_forward
    output = module._old_forward(*args, **kwargs)
  File "/home/chris/github/diffusers/src/diffusers/models/controlnet_flux.py", line 295, in forward
    encoder_hidden_states = torch.cat([controlnet_mode_emb, encoder_hidden_states], dim=1)
RuntimeError: Sizes of tensors must match except in dimension 1. Expected size 1 but got size 2 for tensor number 1 in the list.

This is because of the following:

# set control mode
if control_mode is not None:
    control_mode = torch.tensor(control_mode).to(device, dtype=torch.long)
    control_mode = control_mode.reshape([-1, 1])

i.e. no matter what the effective batch size is, when control_mode is passed as an int it will always be of shape (1, 1). Therefore, my code does the following (see the sketch after this list):

  • First cast control_mode to a torch long tensor and reshape it to have an extra axis.
  • If control_mode was already a list to begin with, it will now be of shape (n, 1). (And we assume n == control_image.shape[0].)
  • If the original type was actually an int, just repeat control_mode along the batch axis to go from (1, 1) to the final shape (control_image.shape[0], 1).
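
As a standalone sketch of that logic (with a hypothetical helper name; the actual change lives inside the pipeline's __call__):

import torch

def prepare_control_mode(control_mode, batch_size, device):
    # Hypothetical helper mirroring the PR's logic: normalize control_mode
    # (an int or a list of ints) into a long tensor of shape (batch_size, 1).
    was_int = isinstance(control_mode, int)
    control_mode = torch.tensor(control_mode).to(device, dtype=torch.long)
    control_mode = control_mode.reshape([-1, 1])
    if was_int:
        # an int gives shape (1, 1); repeat it along the batch axis
        control_mode = control_mode.repeat(batch_size, 1)
    return control_mode

print(prepare_control_mode(0, 2, "cpu").shape)       # torch.Size([2, 1])
print(prepare_control_mode([0, 2], 2, "cpu").shape)  # torch.Size([2, 1])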

I haven't had time to test MultiControlNet this morning, but it also needs similar logic. And I'm also confused about this part of the code:

control_mode_ = []
if isinstance(control_mode, list):
    for cmode in control_mode:
        if cmode is None:
            control_mode_.append(-1)
        else:
            control_mode_.append(cmode)

If control_mode is a list of ints and Nones, then the Nones get converted to -1 (which is necessary since we can't have a torch tensor with mixed dtypes). But as far as I can tell, you also can't index into the mode embedding with -1, since that will give a CUDA index assertion error. So I don't know why it's coded up this way in the first place; shouldn't we just throw an exception if control_mode (as a list) contains Nones or non-ints?
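
For reference, a minimal repro of that indexing behaviour (my own snippet, separate from the pipeline; on CPU the failure is an immediate IndexError rather than a device-side assert):

import torch

emb = torch.nn.Embedding(10, 4)  # stand-in for the controlnet mode embedder
emb(torch.tensor([0]))   # fine
emb(torch.tensor([-1]))  # IndexError: index out of range in self
                         # (on CUDA this surfaces as a device-side assert)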

The PR is not complete but it's a start; it would be good to improve it and to understand how exactly we want to do the conditioning here.

Before submitting

Who can review?

@haofanwang @wangqixun possibly?

@yiyixuxu
Collaborator

@christopher-beckham
thanks for the PR! would you be able to provide a script that currently fails but will pass with this PR? - that will really help!

@HuggingFaceDocBuilderDev

The docs for this PR live here. All of your documentation changes will be reflected on that endpoint. The docs are available until 30 days after the last update.

@christopher-beckham
Contributor Author

christopher-beckham commented Sep 12, 2024

Hi @yiyixuxu

Sure thing. The core of it is the following (taken from the example code here):

import torch
from diffusers.utils import load_image
from diffusers import FluxControlNetPipeline, FluxControlNetModel
from diffusers import AutoencoderKL

if __name__ == '__main__':

    control_image = load_image(
        "https://huggingface.co/InstantX/FLUX.1-dev-Controlnet-Union-alpha/resolve/main/images/canny.jpg"
    ).resize((512,512))
    controlnet_conditioning_scale = 0.5
    control_mode = 0

    width, height = control_image.size

    base_model = 'black-forest-labs/FLUX.1-dev'
    controlnet_model = 'InstantX/FLUX.1-dev-Controlnet-Union'

    controlnet = FluxControlNetModel.from_pretrained(
        controlnet_model, torch_dtype=torch.bfloat16
    )
    vae = AutoencoderKL.from_pretrained(base_model, subfolder="vae", torch_dtype=torch.bfloat16)
    pipe = FluxControlNetPipeline.from_pretrained(
        base_model, vae=vae, controlnet=controlnet, torch_dtype=torch.bfloat16
    )
    print("to cuda...")
    pipe.to("cuda")

    prompt = 'A bohemian-style female travel blogger with sun-kissed skin and messy beach waves.'

    pipe.enable_model_cpu_offload()

    generator = torch.Generator(device="cuda").manual_seed(0)

This should fail pre-fix because we're passing in an int for the control mode:

    image = pipe(
        prompt, 
        num_images_per_prompt=2,
        control_image=control_image,
        control_mode=0,
        width=width,
        height=height,
        controlnet_conditioning_scale=controlnet_conditioning_scale,
        num_inference_steps=1, 
        guidance_scale=3.5,
        generator=generator
    ).images[0]

I don't have time to test the MultiControlNet just now (maybe tonight), but I would strongly presume that if you pass in a list of control modes and any of them are negative, it should throw a CUDA error (since you cannot index into an embedding layer with a negative index). Example below (though this isn't for MultiControlNet):

    image = pipe(
        prompt, 
        num_images_per_prompt=2,
        control_image=control_image,
        control_mode=-1,
        width=width,
        height=height,
        controlnet_conditioning_scale=controlnet_conditioning_scale,
        num_inference_steps=1, 
        guidance_scale=3.5,
        generator=generator
    ).images[0]

@yiyixuxu
Collaborator

thanks @christopher-beckham
controlnet mode is only needed for union controlnet no? do we need to use MultiControlNet with union?
cc @asomoza here too

)

# set control mode
orig_mode_type = type(control_mode)
Collaborator

if we only accept int for single controlnet, maybe raise a value error here

Contributor Author

yes good idea

Contributor Author

There does already appear to be a check for it here inside the controlnet class itself:

if controlnet_mode is None:
    raise ValueError("`controlnet_mode` cannot be `None` when applying ControlNet-Union")

@asomoza
Member

asomoza commented Sep 13, 2024

controlnet mode is only needed for union controlnet no? do we need to use MultiControlNet with union?

The original one worked like that: you pass a tensor with the modes and a list of images, so you don't need the multicontrolnet option.

From what I know of the one made for Flux, it's not the same: it doesn't allow passing multiple modes and images to the same controlnet, so this one would need the multicontrolnet option if you want to use more than one condition to guide the generation.

@christopher-beckham
Contributor Author

controlnet mode is only needed for union controlnet no? do we need to use MultiControlNet with union?

yes. Though confusingly, one could do MultiControlNet where the controlnets passed are actually union controlnets. That seems to be what is assumed in the code, because when you call the MultiControlNet you're able to pass in a control mode. For instance, if none of the controlnets in the MultiControlNet were of the "union" variety, you would just be doing controlnet_1(image) + ... + controlnet_N(image).

I think maybe this is a good time for me to clarify my own understanding of what's going on, since I am not 100% sure. Please correct me if I've misunderstood:

CN-Union

ControlNet-Union is specifically a type of controlnet which is trained to handle multiple input conditions, thanks to the extra conditioning tensor which gets passed in during training. At generation time, it basically allows one to pass in arbitrary input conditions, as long as you tell the controlnet which one to use via control_mode.

This means that you can do the following (written in pseudocode):

block_outs_1 = controlnet_union(control_image_1, mode_1)
pred_noise_1 = transformer(..., block_outs_1)
...
block_outs_N = controlnet_union(control_image_N, mode_N)
pred_noise_N = transformer(..., block_outs_N)

In the FluxPipeline, it should therefore be possible to batch this and do pipeline(prompt, control_images, controlnet_modes), i.e. control_images = [control_image_1, ..., control_image_N] and controlnet_modes = [mode_1, ..., mode_N]. (Let's just assume for this example num_images_per_prompt==1, because if it's > 1 then we need to do some extra work in the pipeline so that things still make sense.)

It should also be possible to do pipeline(control_images, mode_int), in which case it is assumed you want to generate from multiple control images but use the same mode. (What my PR proposed initially.)
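
In call form, those two patterns would look roughly like this (a sketch of the intended API, not necessarily what works today; control_image_1 and control_image_2 are hypothetical PIL images and the mode values are arbitrary):

# one union controlnet, one mode per control image
images = pipe(
    prompt,
    control_image=[control_image_1, control_image_2],
    control_mode=[0, 2],
    num_images_per_prompt=1,
).images

# same, but a single int mode broadcast across all control images
images = pipe(
    prompt,
    control_image=[control_image_1, control_image_2],
    control_mode=0,
    num_images_per_prompt=1,
).images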

MultiControlNet

MultiControlNet basically assumes a list of controlnets, but -- perhaps confusingly -- these controlnets can also be union controlnets (in which case they each support a controlnet_mode), or they can just be regular uni-conditional controlnets and therefore require controlnet_mode=None. But essentially the purpose is to let you do the following:

block_samples_1 = controlnet_1(control_image_1, mode_1, scale_1)
...
block_samples_N = controlnet_N(control_image_N, mode_N, scale_N)

block_samples_combined = block_samples_1 + ... + block_samples_N
pred_noise = transformer(..., block_samples_combined)

If any of the controlnet_i (i = 1...N) models are non-union, we expect mode_i == None; otherwise the modes may also be integers.

(Somewhat coincidentally, block_samples_combined is basically a "union" as well, but it carries a different semantic meaning to what is going on with union controlnets.)
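
As a sketch in Python, the combination step amounts to the following (simplified from what FluxMultiControlNetModel does; the real forward returns separate double/single block samples and takes more arguments):

def multi_controlnet_forward(controlnets, images, modes, scales, **common):
    # Sketch: run each controlnet on its own image/mode/scale and sum the
    # per-block residuals elementwise before they reach the transformer.
    combined = None
    for net, image, mode, scale in zip(controlnets, images, modes, scales):
        block_samples = net(
            controlnet_cond=image,
            controlnet_mode=mode,
            conditioning_scale=scale,
            **common,
        )
        if combined is None:
            combined = list(block_samples)
        else:
            combined = [c + b for c, b in zip(combined, block_samples)]
    return combined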

In both cases

In both cases, it seems reasonable to assume that the number of control modes (i.e. the length of controlnet_mode) should be equal to the number of control images. But in the case of MultiControlNet, we can allow the list of modes to contain None if need be.

I will test all of these things out later today.

@christopher-beckham christopher-beckham force-pushed the christopher-beckham/fix_flux_controlnet_modes branch from 81bc534 to 370f382 on September 16, 2024 20:54
@christopher-beckham
Contributor Author

I've been testing the multi controlnet stuff out as well. Here's the code segment:

control_image = load_image(
    "https://huggingface.co/InstantX/FLUX.1-dev-Controlnet-Union-alpha/resolve/main/images/canny.jpg"
).resize((512,512))

width, height = control_image.size

base_model = 'black-forest-labs/FLUX.1-dev'
controlnet_union = FluxControlNetModel.from_pretrained(
    'InstantX/FLUX.1-dev-Controlnet-Union', torch_dtype=torch.bfloat16
)
controlnet_depth = FluxControlNetModel.from_pretrained(
    "Shakker-Labs/FLUX.1-dev-ControlNet-Depth", torch_dtype=torch.bfloat16
)

multinet = FluxMultiControlNetModel([controlnet_union, controlnet_depth])
vae = AutoencoderKL.from_pretrained(base_model, subfolder="vae", torch_dtype=torch.bfloat16)
pipe = FluxControlNetPipeline.from_pretrained(
    base_model, vae=vae, controlnet=multinet, torch_dtype=torch.bfloat16
)
#pipe.to("cuda")
pipe.enable_model_cpu_offload()

pipe(
    control_image=[control_image, control_image], 
    control_mode=[0, None], 
    controlnet_conditioning_scale=[1., 1.]
)

So I was looking at this again:

control_mode_ = []
if isinstance(control_mode, list):
    for cmode in control_mode:
        if cmode is None:
            control_mode_.append(-1)
        else:
            control_mode_.append(cmode)

About the 'negative indices': for a MultiControlNet this shouldn't raise any issues. For instance, if we have a multi controlnet where the first controlnet is union and the second is "regular" (see my code above), and we pass control_mode as [0, None], it gets internally converted to [0, -1]; but for the latter controlnet the if statement below never executes, so it would not raise a device-side assert:

if self.union:
    # union mode
    if controlnet_mode is None:
        raise ValueError("`controlnet_mode` cannot be `None` when applying ControlNet-Union")
    # union mode emb
    controlnet_mode_emb = self.controlnet_mode_embedder(controlnet_mode)
    encoder_hidden_states = torch.cat([controlnet_mode_emb, encoder_hidden_states], dim=1)
    txt_ids = torch.cat([txt_ids[:1], txt_ids], dim=0)

However, it just makes me wonder why we have that code in the first place -- this wouldn't be a bug fix, rather just cleaning up code which may not ultimately be necessary. Thoughts?

@christopher-beckham
Contributor Author

Lastly, if we try and do this (just for a regular controlnet, not multi):

pipe(control_image=[control_image, control_image], control_mode=0, num_images_per_prompt=1)

We get an error like RuntimeError: shape '[1, 16, 16, 2, 16, 2]' is invalid for input of size 32768 in the call to self._pack_latents here:

height_control_image, width_control_image = control_image.shape[2:]
control_image = self._pack_latents(
    control_image,
    batch_size * num_images_per_prompt,
    num_channels_latents,
    height_control_image,
    width_control_image,
)
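
(To spell out the arithmetic behind that error: the target shape has 1 × 16 × 16 × 2 × 16 × 2 = 16384 elements, i.e. batch_size * num_images_per_prompt == 1, but the packed control latents contain 2 × 16384 = 32768 elements because two control images were passed.)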

Basically, this code:

# 2. Define call parameters
if prompt is not None and isinstance(prompt, str):
    batch_size = 1
elif prompt is not None and isinstance(prompt, list):
    batch_size = len(prompt)
else:
    batch_size = prompt_embeds.shape[0]

is not taking into account the fact that control_image can be a List[PIL.Image.Image] (at least according to the docstring). While it would be easy to hack up a fix for this, it gets rather complicated when you think about all the degrees of freedom here with regard to batching: apparently the prompt can be a list, the control image can be a list, and we also have num_images_per_prompt. Is the actual internal "batch size" meant to be the product of all these things?

(Apparently it can be any one of torch.Tensor, PIL.Image.Image, np.ndarray, List[torch.Tensor], List[PIL.Image.Image], List[np.ndarray], List[List[torch.Tensor]], List[List[np.ndarray]] or List[List[PIL.Image.Image]].)

It seems like a simple fix for now is to just enforce that len(prompts) == len(control_image).
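
Something like the following check near the top of __call__ would do it (a sketch of the idea, not the final wording):

if isinstance(prompt, list) and isinstance(control_image, list):
    if len(prompt) != len(control_image):
        raise ValueError(
            f"Got {len(prompt)} prompts but {len(control_image)} control images; "
            "these must have the same length."
        )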

Thoughts @yiyixuxu @asomoza ? Thanks.

Collaborator

@yiyixuxu yiyixuxu left a comment

thanks! I understand this a little bit better now!
left some comments

)

# set control mode
orig_mode_type = type(control_mode)
Collaborator

I think we should accept either an int or a list containing one int

Suggested change
orig_mode_type = type(control_mode)
if isinstance(control_mode, list):
    control_mode = [control_mode]
if len(control_mode) > 1:
    raise ValueError(" For `FluxControlNet`, `control_mode` should be an `int` or a list contain 1 `int`")

Contributor Author

I think you meant to do the opposite here? i.e. if not isinstance(control_mode, list) then convert control_mode into a singleton list?

Collaborator

oops! yes!

Comment on lines 753 to 755
control_mode = torch.tensor(control_mode).to(device, dtype=torch.long).view(-1, 1)
if orig_mode_type == int:
    control_mode = control_mode.repeat(control_image.shape[0], 1)
Collaborator

Suggested change
control_mode = torch.tensor(control_mode).to(device, dtype=torch.long).view(-1, 1)
if orig_mode_type == int:
    control_mode = control_mode.repeat(control_image.shape[0], 1)
control_mode = torch.tensor(control_mode).to(device, dtype=torch.long)
control_mode = control_mode.view(-1, 1).expand(control_image.shape[0], 1)
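
(Side note on the suggestion: expand returns a broadcast view without copying, whereas repeat materializes the copies; for a (1, 1) tensor the result is elementwise identical, e.g. torch.tensor([[0]]).expand(4, 1) equals torch.tensor([[0]]).repeat(4, 1), but expand only works when the dimension being expanded has size 1.)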

Comment on lines 798 to 801
    control_mode = torch.tensor(control_mode_).to(device, dtype=torch.long)
    control_mode = control_mode.view(-1, 1)
else:
    raise ValueError("For multi-controlnet, control_mode should be a list")
Collaborator

I think we need to make sure each control_mode has the batch size too for multi-controlnet

Suggested change
    control_mode = torch.tensor(control_mode_).to(device, dtype=torch.long)
    control_mode = control_mode.view(-1, 1)
else:
    raise ValueError("For multi-controlnet, control_mode should be a list")
    control_mode = torch.tensor(control_mode_).to(device, dtype=torch.long)
    control_mode = control_mode.view(-1, 1).expand(control_images[0].shape[0], 1)
else:
    raise ValueError("For multi-controlnet, control_mode should be a list")

Contributor Author

Ok I made this change as well. (Will commit in a sec)

However, unlike the regular controlnet if block, I do not explicitly allow control_mode to be an int here, i.e. it won't automagically be converted to a singleton list. I think in the "multi" case it's best to be explicit -- and remind the user -- to think in terms of it being a list, even if the multi controlnet object just has one controlnet contained inside it.

@yiyixuxu
Collaborator

thanks for your investigation!!
Yes, I think we should start to have some validation, and make it clear what the input format should be when we want to use multiple images + controlnets + num_images_per_prompt (currently very confusing!)
but since we are not doing any of these checks yet, maybe we can do it in a separate PR, and probably across all our controlnets

@christopher-beckham
Contributor Author

Thanks for your suggestions! I just made a commit incorporating your changes manually (not sure what the best practice is, i.e. whether I should have committed your changes directly and then modified them, but anyway, the changes are made).

Regarding my changes: for multi-controlnet, the control mode must be passed as a list, even if it's a singleton. I didn't want to automagically convert it from an int to a singleton list (like in the single controlnet case), since it seems more fitting to be explicit about it in the multi case.

Yes, I think we should start to have some validation, and make it clear what the input format should be when we want to use multiple images + controlnets + num_images_per_prompt (currently very confusing!) but since we are not doing any of these checks yet, maybe we can do it in a separate PR, and probably across all our controlnets

Yeah, maybe we can explore this in another PR. When I looked at the SDXL controlnet class, it basically looks like the prompt is the "ground truth" for how batching should be handled -- that is, the control images variable (whether it be a List[PIL.Image], torch.Tensor, or np.ndarray) is checked to make sure its length (batch size) is equivalent to the number of prompts. Once that is satisfied, num_images_per_prompt can then double/triple/whatever the effective batch size. A lot of that logic can be ported over here (a sketch follows below).
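
A sketch of that flow (loosely based on the SDXL controlnet pipeline's image preparation; names simplified and hypothetical):

def prepare_control(control_latents, batch_size, num_images_per_prompt):
    # Sketch of the SDXL-style contract: the control batch must match the
    # prompt batch size (or be 1, in which case it is broadcast), and only
    # then is it repeated for num_images_per_prompt.
    if control_latents.shape[0] == 1:
        control_latents = control_latents.repeat(batch_size, 1, 1, 1)
    elif control_latents.shape[0] != batch_size:
        raise ValueError(
            f"control image batch size {control_latents.shape[0]} does not "
            f"match prompt batch size {batch_size}"
        )
    return control_latents.repeat_interleave(num_images_per_prompt, dim=0)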

Lastly, I have some unit tests I used here:

https://github.com/christopher-beckham/diffusers-tests/blob/master/flux_controlnet/test_multicontrolnet.py

control_mode = torch.tensor(
    [-1 if elem is None else elem for elem in control_mode]
)
control_mode = control_mode.view(-1, 1).expand(len(control_image), 1)
Collaborator

is this right though? for multi controlnet, we loop through each controlnet, control_image, and control_mode, and the control_image element and control_mode do not have the same batch_size

for i, (image, mode, scale) in enumerate(zip(controlnet_cond, controlnet_mode, conditioning_scale)):

Contributor Author

For the following setup:

controlnet_union = FluxControlNetModel.from_pretrained(
    'InstantX/FLUX.1-dev-Controlnet-Union', torch_dtype=torch.bfloat16
)
controlnet_depth = FluxControlNetModel.from_pretrained(
    "Shakker-Labs/FLUX.1-dev-ControlNet-Depth", torch_dtype=torch.bfloat16
)

multinet = FluxMultiControlNetModel([controlnet_union, controlnet_depth])
vae = AutoencoderKL.from_pretrained(base_model, subfolder="vae", torch_dtype=torch.bfloat16)
pipe = FluxControlNetPipeline.from_pretrained(
    base_model, vae=vae, controlnet=multinet, torch_dtype=torch.bfloat16
)

def pil_to_numpy(image):
    """to (c,h,w)"""
    return (np.array(image).astype(np.float32)/255.).swapaxes(1,2).swapaxes(0,1)

def pil_to_torch(image):
    return torch.from_numpy(pil_to_numpy(image)).float()

control_image = load_image(
    "https://huggingface.co/InstantX/FLUX.1-dev-Controlnet-Union-alpha/resolve/main/images/canny.jpg"
).resize((512,512))

If you're asking whether passing in an actual tensor for control_image would work (instead of PIL), i.e.:

# this will have shape (3,512,512)
control_image = pil_to_torch(control_image)
# this will have shape (4,3,512,512), i.e. batch size of 4
control_image = control_image.unsqueeze(0).repeat(4,1,1,1)

pipe(
    # the controlnets here are [union, union]
    control_image=[control_image, control_image],  # each inner control_image is batched
    control_mode=[0, None], 
    controlnet_conditioning_scale=[1., 1.]
)

We will get an error at _pack_latents; again, this is related to the fact that the pipeline completely ignores the batch size of control_image. We initially mentioned fixing this in a future PR, but it also depends on which other pipelines have this issue. I looked at SDXL controlnet and it seems not to have this issue. It wouldn't be unreasonable to fix it in this PR, but we can do a new PR if you prefer.

Also, I can confirm that this works:

# This is a PIL image
control_image = obj["control_image"]

pipe(
    control_image=[control_image, control_image], 
    control_mode=[0, None], 
    controlnet_conditioning_scale=[1., 1.]
)

@yiyixuxu
Collaborator

yiyixuxu commented Sep 23, 2024

hi @christopher-beckham
I continued this PR in #9507 (it is branched out from your PR, so all your commits are there and you're already a co-author there)
would really appreciate it if you can give a review!!

I made a slow test with 4 test cases here #9507 (comment)

I think we do not need to consider something like multinet = FluxMultiControlNetModel([controlnet_union, controlnet_depth]) because:

  1. it is not very meaningful (the union controlnet already has a mode for the depth condition)
  2. there is currently no checkpoint that makes it possible to use it this way; the only regular controlnet checkpoint we have is canny, which does not have the single blocks, so it cannot be used together with union at all

@christopher-beckham
Contributor Author

Thanks @yiyixuxu I'm gonna check it out shortly :)

@github-actions
Contributor

This issue has been automatically marked as stale because it has not had recent activity. If you think this still needs to be addressed please comment on this thread.

Please note that issues that do not follow the contributing guidelines are likely to be ignored.

@github-actions github-actions bot added the stale Issues that haven't received updates label Oct 20, 2024
@a-r-r-o-w a-r-r-o-w removed the stale Issues that haven't received updates label Oct 20, 2024
@yiyixuxu
Collaborator

closing this since PR is merged!

@yiyixuxu yiyixuxu closed this Oct 22, 2024