SD3ControlNetModel forward function error · Issue #9496 · huggingface/diffusers · GitHub

SD3ControlNetModel forward function error #9496

@pibbo88

### Describe the bug

The code below is lines 326-352 of `diffusers/models/controlnet_sd3.py`. In the `if` branch, the value returned by `torch.utils.checkpoint.checkpoint` is a tuple and is assigned to `hidden_states` as a whole, while the `else` branch unpacks the output of `block` so that `hidden_states` is a tensor. The layers that follow require a tensor, so training with gradient checkpointing enabled raises an error.

```python
for block in self.transformer_blocks:
    if self.training and self.gradient_checkpointing:

        def create_custom_forward(module, return_dict=None):
            def custom_forward(*inputs):
                if return_dict is not None:
                    return module(*inputs, return_dict=return_dict)
                else:
                    return module(*inputs)

            return custom_forward

        ckpt_kwargs: Dict[str, Any] = {"use_reentrant": False} if is_torch_version(">=", "1.11.0") else {}
        hidden_states = torch.utils.checkpoint.checkpoint(
            create_custom_forward(block),
            hidden_states,
            encoder_hidden_states,
            temb,
            **ckpt_kwargs,
        )

    else:
        encoder_hidden_states, hidden_states = block(
            hidden_states=hidden_states, encoder_hidden_states=encoder_hidden_states, temb=temb
        )

    block_res_samples = block_res_samples + (hidden_states,)
```
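The mismatch can be reproduced without torch. The sketch below is only an illustration: `ToyBlock`, `fake_checkpoint`, and `run_blocks` are hypothetical stand-ins, not diffusers or PyTorch APIs. It assumes the block returns an `(encoder_hidden_states, hidden_states)` pair, as the `else` branch above suggests, and that the checkpoint wrapper passes that return value through unchanged. Unpacking the pair in the checkpointed branch, the same way the `else` branch does, looks like one possible fix.

```python
# Hypothetical stand-ins for illustration only.
# ToyBlock mimics a transformer block that returns a pair
# (encoder_hidden_states, hidden_states) and rejects tuple inputs,
# loosely like layer_norm rejecting a tuple in the traceback.
class ToyBlock:
    def __call__(self, hidden_states, encoder_hidden_states, temb):
        if isinstance(hidden_states, tuple):
            raise TypeError("hidden_states must not be a tuple")
        return encoder_hidden_states + temb, hidden_states + temb


def fake_checkpoint(function, *args):
    # Mimics torch.utils.checkpoint.checkpoint only in that it returns
    # whatever the wrapped function returns.
    return function(*args)


def run_blocks(blocks, hidden_states, encoder_hidden_states, temb, unpack):
    block_res_samples = ()
    for block in blocks:
        out = fake_checkpoint(block, hidden_states, encoder_hidden_states, temb)
        if unpack:
            # Possible fix: unpack the pair, as the non-checkpointed branch does.
            encoder_hidden_states, hidden_states = out
        else:
            # Reported behaviour: the whole pair is bound to hidden_states.
            hidden_states = out
        block_res_samples = block_res_samples + (hidden_states,)
    return block_res_samples


blocks = [ToyBlock(), ToyBlock()]

# With unpacking, every block receives a plain value.
fixed = run_blocks(blocks, 1.0, 10.0, 0.5, unpack=True)

# Without unpacking, the second block receives a tuple and fails.
try:
    run_blocks(blocks, 1.0, 10.0, 0.5, unpack=False)
    buggy_error = None
except TypeError as exc:
    buggy_error = str(exc)

print(fixed)
print(buggy_error)
```

In the toy version the first block succeeds and the failure only appears when the second block receives the tuple, which matches the traceback below, where the error is raised inside a block's `norm1` call.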

### Reproduction

Run `examples/controlnet/train_controlnet_sd3.py` with `--gradient_checkpointing` set.

### Logs

```shell
Traceback (most recent call last):
  File "/workspace/pge/tools/miniconda3/envs/sdtrain/lib/python3.10/runpy.py", line 196, in _run_module_as_main
    return _run_code(code, main_globals, None,
  File "/workspace/pge/tools/miniconda3/envs/sdtrain/lib/python3.10/runpy.py", line 86, in _run_code
    exec(code, run_globals)
  File "/root/.vscode-server/extensions/ms-python.debugpy-2024.10.0-linux-x64/bundled/libs/debugpy/adapter/../../debugpy/launcher/../../debugpy/__main__.py", line 39, in <module>
    cli.main()
  File "/root/.vscode-server/extensions/ms-python.debugpy-2024.10.0-linux-x64/bundled/libs/debugpy/adapter/../../debugpy/launcher/../../debugpy/../debugpy/server/cli.py", line 430, in main
    run()
  File "/root/.vscode-server/extensions/ms-python.debugpy-2024.10.0-linux-x64/bundled/libs/debugpy/adapter/../../debugpy/launcher/../../debugpy/../debugpy/server/cli.py", line 284, in run_file
    runpy.run_path(target, run_name="__main__")
  File "/root/.vscode-server/extensions/ms-python.debugpy-2024.10.0-linux-x64/bundled/libs/debugpy/_vendored/pydevd/_pydevd_bundle/pydevd_runpy.py", line 321, in run_path
    return _run_module_code(code, init_globals, run_name,
  File "/root/.vscode-server/extensions/ms-python.debugpy-2024.10.0-linux-x64/bundled/libs/debugpy/_vendored/pydevd/_pydevd_bundle/pydevd_runpy.py", line 135, in _run_module_code
    _run_code(code, mod_globals, init_globals,
  File "/root/.vscode-server/extensions/ms-python.debugpy-2024.10.0-linux-x64/bundled/libs/debugpy/_vendored/pydevd/_pydevd_bundle/pydevd_runpy.py", line 124, in _run_code
    exec(code, run_globals)
  File "/workspace/pge/projects/sd_trainer/train_sd3_controlnet.py", line 1099, in <module>
    main(args)
  File "/workspace/pge/projects/sd_trainer/train_sd3_controlnet.py", line 975, in main
    control_block_res_samples = controlnet(
  File "/workspace/pge/tools/miniconda3/envs/sdtrain/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1553, in _wrapped_call_impl
    return self._call_impl(*args, **kwargs)
  File "/workspace/pge/tools/miniconda3/envs/sdtrain/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1562, in _call_impl
    return forward_call(*args, **kwargs)
  File "/workspace/pge/tools/miniconda3/envs/sdtrain/lib/python3.10/site-packages/accelerate/utils/operations.py", line 820, in forward
    return model_forward(*args, **kwargs)
  File "/workspace/pge/tools/miniconda3/envs/sdtrain/lib/python3.10/site-packages/accelerate/utils/operations.py", line 808, in __call__
    return convert_to_fp32(self.model_forward(*args, **kwargs))
  File "/workspace/pge/tools/miniconda3/envs/sdtrain/lib/python3.10/site-packages/torch/amp/autocast_mode.py", line 43, in decorate_autocast
    return func(*args, **kwargs)
  File "/workspace/pge/projects/sd_trainer/diffusers/src/diffusers/models/controlnet_sd3.py", line 339, in forward
    hidden_states = torch.utils.checkpoint.checkpoint(
  File "/workspace/pge/tools/miniconda3/envs/sdtrain/lib/python3.10/site-packages/torch/_compile.py", line 31, in inner
    return disable_fn(*args, **kwargs)
  File "/workspace/pge/tools/miniconda3/envs/sdtrain/lib/python3.10/site-packages/torch/_dynamo/eval_frame.py", line 600, in _fn
    return fn(*args, **kwargs)
  File "/workspace/pge/tools/miniconda3/envs/sdtrain/lib/python3.10/site-packages/torch/utils/checkpoint.py", line 488, in checkpoint
    ret = function(*args, **kwargs)
  File "/workspace/pge/projects/sd_trainer/diffusers/src/diffusers/models/controlnet_sd3.py", line 334, in custom_forward
    return module(*inputs)
  File "/workspace/pge/tools/miniconda3/envs/sdtrain/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1553, in _wrapped_call_impl
    return self._call_impl(*args, **kwargs)
  File "/workspace/pge/tools/miniconda3/envs/sdtrain/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1562, in _call_impl
    return forward_call(*args, **kwargs)
  File "/workspace/pge/projects/sd_trainer/diffusers/src/diffusers/models/attention.py", line 162, in forward
    norm_hidden_states, gate_msa, shift_mlp, scale_mlp, gate_mlp = self.norm1(hidden_states, emb=temb)
  File "/workspace/pge/tools/miniconda3/envs/sdtrain/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1553, in _wrapped_call_impl
    return self._call_impl(*args, **kwargs)
  File "/workspace/pge/tools/miniconda3/envs/sdtrain/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1562, in _call_impl
    return forward_call(*args, **kwargs)
  File "/workspace/pge/projects/sd_trainer/diffusers/src/diffusers/models/normalization.py", line 139, in forward
    x = self.norm(x) * (1 + scale_msa[:, None]) + shift_msa[:, None]
  File "/workspace/pge/tools/miniconda3/envs/sdtrain/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1553, in _wrapped_call_impl
    return self._call_impl(*args, **kwargs)
  File "/workspace/pge/tools/miniconda3/envs/sdtrain/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1562, in _call_impl
    return forward_call(*args, **kwargs)
  File "/workspace/pge/tools/miniconda3/envs/sdtrain/lib/python3.10/site-packages/torch/nn/modules/normalization.py", line 202, in forward
    return F.layer_norm(
  File "/workspace/pge/tools/miniconda3/envs/sdtrain/lib/python3.10/site-packages/torch/nn/functional.py", line 2576, in layer_norm
    return torch.layer_norm(input, normalized_shape, weight, bias, eps, torch.backends.cudnn.enabled)
TypeError: layer_norm(): argument 'input' (position 1) must be Tensor, not tuple
```

### System Info

  • 🤗 Diffusers version: 0.31.0.dev0
  • Platform: Linux-5.15.0-105-generic-x86_64-with-glibc2.39
  • Running on Google Colab?: No
  • Python version: 3.10.14
  • PyTorch version (GPU?): 2.4.1+cu121 (True)
  • Flax version (CPU?/GPU?/TPU?): not installed (NA)
  • Jax version: not installed
  • JaxLib version: not installed
  • Huggingface_hub version: 0.25.0
  • Transformers version: 4.44.2
  • Accelerate version: 0.34.2
  • PEFT version: not installed
  • Bitsandbytes version: 0.43.3
  • Safetensors version: 0.4.5
  • xFormers version: 0.0.28.post1
  • Accelerator: NVIDIA A40, 46068 MiB
    NVIDIA A40, 46068 MiB
  • Using GPU in script?:
  • Using distributed or parallel set-up in script?:

### Who can help?

@yiyixuxu @sayakpaul @DN6 @aso
