-
Notifications
You must be signed in to change notification settings - Fork 6.4k
Fix flux controlnet mode to take into account batch size #9406
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Fix flux controlnet mode to take into account batch size #9406
Conversation
@christopher-beckham |
The docs for this PR live here. All of your documentation changes will be reflected on that endpoint. The docs are available until 30 days after the last update. |
Hi @yiyixuxu Sure thing. The core of it is the following (taken from the example code here):
This should fail pre-fix because we're passing in an int for the control mode:
I don't have time to test the MultiControlNet just now (maybe tonight), but I would strongly presume that if you pass in a list of control modes and any of them are -ve then it should throw a cuda error (since you cannot index into an embedding layer with a -ve index). Example below (though this isn't for multi control net)
|
thanks @christopher-beckham |
) | ||
|
||
# set control mode | ||
orig_mode_type = type(control_mode) |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
if we only accept int for single controlnet, maybe raise a value error here
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
yes good idea
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
There does already appear to be a check for it here inside the controlnet class itself:
diffusers/src/diffusers/models/controlnet_flux.py
Lines 289 to 290 in b52119a
if controlnet_mode is None: | |
raise ValueError("`controlnet_mode` cannot be `None` when applying ControlNet-Union") |
The original one worked like that but you pass a tensor with the modes and an image list, so you don't need the multicontrolnet option. For what I know of the one made for Flux, it's not the same, it doesn't allow to pass multiple modes and images to the same controlnet so this one would need the multicontrolnet option if you want to use more than one condition to guide the generation. |
yes. Though confusingly, one could do I think maybe is a good time for me to clarify my own understanding of what's going on since I am not 100% sure. Please correct me if I've misunderstood: CN-UnionControlNet-Union is specifically a type of controlnet which is trained to handle multiple input conditions due to the extra conditioning tensor which gets passed in when it is trained. At generation time, it basically allows one to pass in arbitrary input conditions as long as you tell the controlnet what it use via This means that you can do the following (written in pseudocode):
In the It should also be possible to do MultiControlNet
if any of the (Somewhat coincidentally, In both casesIn both cases, it seems reasonable to assume that the number of control modes (i.e. length of I will test all of these things out later today. |
81bc534
to
370f382
Compare
I've been testing the multi controlnet stuff out as well. Here's the code segment:
So I was looking at this again: diffusers/src/diffusers/pipelines/flux/pipeline_flux_controlnet.py Lines 789 to 795 in f28a8c2
About the 'negative indices', for a MultiControlNet this shouldn't raise any issues -- for instance, if we have a multi controlnet where the first controlnet is union and the second is "regular" (see my above code), if we passed diffusers/src/diffusers/models/controlnet_flux.py Lines 287 to 294 in f28a8c2
However, it just makes me wonder why we have that code in the first place -- this wouldn't be a bug fix, rather just cleaning up code which may not ultimately be necessary. Thoughts? |
Lastly, if we try and do this (just for a regular controlnet, not multi):
We get an error like diffusers/src/diffusers/pipelines/flux/pipeline_flux_controlnet.py Lines 741 to 748 in f28a8c2
Basically, this code: diffusers/src/diffusers/pipelines/flux/pipeline_flux_controlnet.py Lines 693 to 699 in f28a8c2
is not taking into account the fact that (Apparently it can be any one of It seems like a simple fix for now is to just enforce that |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
thanks! I understand this a little bit better now!
left some comments
) | ||
|
||
# set control mode | ||
orig_mode_type = type(control_mode) |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
I think we should accept both int
or a list (of length 1) int
orig_mode_type = type(control_mode) | |
if isinstance(control_mode, list): | |
control_mode = [control_mode] | |
if len(control_mode) > 1: | |
raise ValueError(" For `FluxControlNet`, `control_mode` should be an `int` or a list contain 1 `int`") |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
I think you meant to do the opposite here? i.e. if not isinstance(control_mode, list)
then convert control_mode
into a singleton list?
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
oops! yes!
control_mode = torch.tensor(control_mode).to(device, dtype=torch.long).view(-1,1) | ||
if orig_mode_type == int: | ||
control_mode = control_mode.repeat(control_image.shape[0], 1) |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
control_mode = torch.tensor(control_mode).to(device, dtype=torch.long).view(-1,1) | |
if orig_mode_type == int: | |
control_mode = control_mode.repeat(control_image.shape[0], 1) | |
control_mode = torch.tensor(control_mode).to(device, dtype=torch.long) | |
control_model = control_mode.view(-1,1).expand(control_image.shape[0], 1) |
control_mode = torch.tensor(control_mode_).to(device, dtype=torch.long) | ||
control_mode = control_mode.view(-1, 1) | ||
else: | ||
raise ValueError("For multi-controlnet, control_mode should be a list") |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
I think we need to make sure each control_mode has batch_size too for multi-controlnet
control_mode = torch.tensor(control_mode_).to(device, dtype=torch.long) | |
control_mode = control_mode.view(-1, 1) | |
else: | |
raise ValueError("For multi-controlnet, control_mode should be a list") | |
control_mode = torch.tensor(control_mode_).to(device, dtype=torch.long) | |
control_mode = control_mode.view(-1, 1).expand(control_images[0].shape[0] | |
else: | |
raise ValueError("For multi-controlnet, control_mode should be a list") |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Ok I made this change as well. (Will commit in a sec)
However, unlike the regular controlnet if block I do not explicitly allow for control_mode
to be an int here, i.e. it won't automagically be converted to a singleton list. I think maybe in the "multi" case it's best to be explicit -- and remind the user -- to think in terms of it being a list, even if the multi controlnet object just has one controlnet contained inside it.
thanks for your investigation!! |
…p control mode handling for multi case
Thanks for your suggestions! I just made a commit incorporating your changes manually (not sure what the best practice is, i.e. whether I should have committed your changes in directly then modified them, but anyways the changes are made). wrt to my changes, for multi-controlnet, the control mode must be passed as a list, even if it's a singleton. I didn't want to automagically convert it from int to a singleton list (like in the single controlnet case) since it seems for fitting to be explicit about it in the multi case.
Yeah maybe for another PR we can explore this. When I looked at the SDXL controlnet class, it basically looks like the prompt is the "ground truth" for how batching should be handled -- that is, the control images variable (whether it be a Lastly, I have some unit tests I used here: |
control_mode = torch.tensor( | ||
[-1 if elem is None else elem for elem in control_mode] | ||
) | ||
control_mode = control_mode.view(-1,1).expand(len(control_image), 1) |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
is this right though? for multi controlnet, we loop through each controlnet, control_image and control_mode, and control_image element and control_mode does not have same batch_size
for i, (image, mode, scale) in enumerate(zip(controlnet_cond, controlnet_mode, conditioning_scale)): |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
For the following setup:
controlnet_union = FluxControlNetModel.from_pretrained(
'InstantX/FLUX.1-dev-Controlnet-Union', torch_dtype=torch.bfloat16
)
controlnet_depth = FluxControlNetModel.from_pretrained(
"Shakker-Labs/FLUX.1-dev-ControlNet-Depth", torch_dtype=torch.bfloat16
)
multinet = FluxMultiControlNetModel([controlnet_union, controlnet_depth])
vae = AutoencoderKL.from_pretrained(base_model, subfolder="vae", torch_dtype=torch.bfloat16)
pipe = FluxControlNetPipeline.from_pretrained(
base_model, vae=vae, controlnet=multinet, torch_dtype=torch.bfloat16
)
def pil_to_numpy(image):
"""to (c,h,w)"""
return (np.array(image).astype(np.float32)/255.).swapaxes(1,2).swapaxes(0,1)
def pil_to_torch(image):
return torch.from_numpy(pil_to_numpy(image)).float()
control_image = load_image(
"https://huggingface.co/InstantX/FLUX.1-dev-Controlnet-Union-alpha/resolve/main/images/canny.jpg"
).resize((512,512))
If you're asking whether passing in an actual tensor for control_image
would work (instead of PIL), i.e:
# this will have shape (3,512,512)
control_image = pil_to_torch(control_image)
# this will have shape (4,3,512,512), i.e. batch size of 4
control_image = control_image.unsqueeze(0).repeat(4,1,1,1)
pipe(
# the controlnets here are [union, union]
control_image=[control_image, control_image], # each inner control_image is batched
control_mode=[0, None],
controlnet_conditioning_scale=[1., 1.]
)
We will get an error at _pack_latents
, again this is related to the fact that the pipeline completely ignores what the batch size of control_image
. We initially mentioned fixing this in a future PR but it also depends on what other pipelines have this issue. I looked at SDXL controlnet and it seems to not have this issue. It wouldn't be unreasonable to pursue fixing it in this PR but we can do a new PR if you prefer that.
Also, I can clarify this works:
# This is a PIL image
control_image = obj["control_image"]
pipe(
control_image=[control_image, control_image],
control_mode=[0, None],
controlnet_conditioning_scale=[1., 1.]
)
hi @christopher-beckham I made a slow test with 4 test cases here #9507 (comment) I think we do not need to consider something like
|
Thanks @yiyixuxu I'm gonna check it out shortly :) |
This issue has been automatically marked as stale because it has not had recent activity. If you think this still needs to be addressed please comment on this thread. Please note that issues that do not follow the contributing guidelines are likely to be ignored. |
closing this since PR is merged! |
What does this PR do?
For some reason the
control_mode
in the Flux ControlNet pipeline does not take into account the batch size, which results in a downstream error to the call to controlnet due to a shape mismatch. (This is assuming thatcontrol_mode
is an int, but the docstring implies that it's a valid configuration, i.e.Optional[Union[int, List[int]]] = None
.)This is because of the following:
diffusers/src/diffusers/pipelines/flux/pipeline_flux_controlnet.py
Lines 750 to 753 in f28a8c2
i.e. no matter what the effective batch size is, when
control_mode
is passed as an int it will always be of the shape(1,1)
. Therefore, my code does the following:control_mode
into a torch long tensor and reshape it to have an extra axis.control_mode
was already a list to begin with, it will now be of shape(n, 1)
. (And we assumen == control_image.shape[0]
.)control_mode
on the batch axis to go from(1,1)
to obtain the final shape(control_image.shape[0],1)
.I haven't had the time to test
MultiControlNet
this morning but it also needs similar logic. And I'm also confused about this part of the code:diffusers/src/diffusers/pipelines/flux/pipeline_flux_controlnet.py
Lines 789 to 795 in f28a8c2
If
control_mode
is a list of ints and None's then the None's will get converted to-1
(which is necessary since we can't have a torch tensor with mixed dtypes). But as far as I can tell you also can't index into the necessary embedding with-1
since that will give a cuda index assertion error. So I don't know why it's coded up this way in the first place, shouldn't we just throw an exception ifcontrol_mode
(as a list) has Nones or non-ints?The PR is not complete but it's a start, it would be good to improve and understand how exactly we want to do the conditioning here.
Before submitting
documentation guidelines, and
here are tips on formatting docstrings.
Who can review?
@haofanwang @wangqixun possibly?