Closed
Labels
bug (Something isn't working)
### Describe the bug
The following code is lines 326-352 of diffusers/models/controlnet_sd3.py. In the `if` branch, `torch.utils.checkpoint.checkpoint` returns whatever the wrapped block returns, which is a tuple, and the whole tuple is assigned to `hidden_states`. In the `else` branch, the return value of `block` is unpacked, so `hidden_states` is a tensor. The layers that follow require a tensor, so training with gradient checkpointing enabled raises an error.
```python
for block in self.transformer_blocks:
    if self.training and self.gradient_checkpointing:

        def create_custom_forward(module, return_dict=None):
            def custom_forward(*inputs):
                if return_dict is not None:
                    return module(*inputs, return_dict=return_dict)
                else:
                    return module(*inputs)

            return custom_forward

        ckpt_kwargs: Dict[str, Any] = {"use_reentrant": False} if is_torch_version(">=", "1.11.0") else {}
        hidden_states = torch.utils.checkpoint.checkpoint(
            create_custom_forward(block),
            hidden_states,
            encoder_hidden_states,
            temb,
            **ckpt_kwargs,
        )

    else:
        encoder_hidden_states, hidden_states = block(
            hidden_states=hidden_states, encoder_hidden_states=encoder_hidden_states, temb=temb
        )

    block_res_samples = block_res_samples + (hidden_states,)
```
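One possible fix is to unpack the checkpointed call's return value the same way the `else` branch does. The sketch below shows this with a toy stand-in for the transformer block. `ToyJointBlock`, `run_blocks`, and the tensor shapes are hypothetical and only mimic the `(encoder_hidden_states, hidden_states)` return order seen in the `else` branch; this is not necessarily the patch applied upstream.

```python
import torch
import torch.nn as nn
from torch.utils.checkpoint import checkpoint


class ToyJointBlock(nn.Module):
    """Hypothetical stand-in for a joint transformer block.

    Like the block in the else branch, it returns a tuple
    (encoder_hidden_states, hidden_states).
    """

    def __init__(self, dim: int):
        super().__init__()
        self.norm = nn.LayerNorm(dim)
        self.proj = nn.Linear(dim, dim)

    def forward(self, hidden_states, encoder_hidden_states, temb):
        hidden_states = self.proj(self.norm(hidden_states)) + temb[:, None]
        return encoder_hidden_states, hidden_states


def run_blocks(blocks, hidden_states, encoder_hidden_states, temb, use_checkpoint):
    block_res_samples = ()
    for block in blocks:
        if use_checkpoint:
            # Unpack the tuple so hidden_states stays a tensor for the next block.
            encoder_hidden_states, hidden_states = checkpoint(
                block, hidden_states, encoder_hidden_states, temb, use_reentrant=False
            )
        else:
            encoder_hidden_states, hidden_states = block(
                hidden_states=hidden_states, encoder_hidden_states=encoder_hidden_states, temb=temb
            )
        block_res_samples = block_res_samples + (hidden_states,)
    return block_res_samples


torch.manual_seed(0)
blocks = nn.ModuleList([ToyJointBlock(8) for _ in range(2)])
hidden = torch.randn(2, 4, 8, requires_grad=True)
encoder = torch.randn(2, 3, 8)
temb = torch.randn(2, 8)

with_ckpt = run_blocks(blocks, hidden, encoder, temb, use_checkpoint=True)
without_ckpt = run_blocks(blocks, hidden, encoder, temb, use_checkpoint=False)
```

With the unpacking in place, both paths should collect one tensor per block in `block_res_samples`, and the checkpointed and non-checkpointed results should match.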
### Reproduction
Run `examples/controlnet/train_controlnet_sd3.py` with `--gradient_checkpointing` set.
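The type mismatch can also be seen without the training script. The following is a minimal sketch, assuming only PyTorch is installed. `TupleBlock` is a hypothetical module that, like the SD3 block, returns a tuple and expects a tensor as its first input; it is not part of diffusers.

```python
import torch
import torch.nn as nn
from torch.utils.checkpoint import checkpoint


class TupleBlock(nn.Module):
    """Hypothetical block that returns (encoder_hidden_states, hidden_states)."""

    def __init__(self, dim: int):
        super().__init__()
        self.norm = nn.LayerNorm(dim)

    def forward(self, hidden_states, encoder_hidden_states):
        return encoder_hidden_states, self.norm(hidden_states)


blocks = [TupleBlock(8), TupleBlock(8)]
hidden_states = torch.randn(2, 4, 8, requires_grad=True)
encoder_hidden_states = torch.randn(2, 3, 8)

error_message = None
try:
    for block in blocks:
        # Same pattern as the if branch: the whole tuple is assigned to hidden_states,
        # so the second block receives a tuple instead of a tensor.
        hidden_states = checkpoint(block, hidden_states, encoder_hidden_states, use_reentrant=False)
except TypeError as exc:
    error_message = str(exc)

print(error_message)
```

The first block succeeds; the failure is expected on the second iteration, when `LayerNorm` is handed the tuple produced by the first.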
### Logs
```shell
Traceback (most recent call last):
File "/workspace/pge/tools/miniconda3/envs/sdtrain/lib/python3.10/runpy.py", line 196, in _run_module_as_main
return _run_code(code, main_globals, None,
File "/workspace/pge/tools/miniconda3/envs/sdtrain/lib/python3.10/runpy.py", line 86, in _run_code
exec(code, run_globals)
File "/root/.vscode-server/extensions/ms-python.debugpy-2024.10.0-linux-x64/bundled/libs/debugpy/adapter/../../debugpy/launcher/../../debugpy/__main__.py", line 39, in <module>
cli.main()
File "/root/.vscode-server/extensions/ms-python.debugpy-2024.10.0-linux-x64/bundled/libs/debugpy/adapter/../../debugpy/launcher/../../debugpy/../debugpy/server/cli.py", line 430, in main
run()
File "/root/.vscode-server/extensions/ms-python.debugpy-2024.10.0-linux-x64/bundled/libs/debugpy/adapter/../../debugpy/launcher/../../debugpy/../debugpy/server/cli.py", line 284, in run_file
runpy.run_path(target, run_name="__main__")
File "/root/.vscode-server/extensions/ms-python.debugpy-2024.10.0-linux-x64/bundled/libs/debugpy/_vendored/pydevd/_pydevd_bundle/pydevd_runpy.py", line 321, in run_path
return _run_module_code(code, init_globals, run_name,
File "/root/.vscode-server/extensions/ms-python.debugpy-2024.10.0-linux-x64/bundled/libs/debugpy/_vendored/pydevd/_pydevd_bundle/pydevd_runpy.py", line 135, in _run_module_code
_run_code(code, mod_globals, init_globals,
File "/root/.vscode-server/extensions/ms-python.debugpy-2024.10.0-linux-x64/bundled/libs/debugpy/_vendored/pydevd/_pydevd_bundle/pydevd_runpy.py", line 124, in _run_code
exec(code, run_globals)
File "/workspace/pge/projects/sd_trainer/train_sd3_controlnet.py", line 1099, in <module>
main(args)
File "/workspace/pge/projects/sd_trainer/train_sd3_controlnet.py", line 975, in main
control_block_res_samples = controlnet(
File "/workspace/pge/tools/miniconda3/envs/sdtrain/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1553, in _wrapped_call_impl
return self._call_impl(*args, **kwargs)
File "/workspace/pge/tools/miniconda3/envs/sdtrain/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1562, in _call_impl
return forward_call(*args, **kwargs)
File "/workspace/pge/tools/miniconda3/envs/sdtrain/lib/python3.10/site-packages/accelerate/utils/operations.py", line 820, in forward
return model_forward(*args, **kwargs)
File "/workspace/pge/tools/miniconda3/envs/sdtrain/lib/python3.10/site-packages/accelerate/utils/operations.py", line 808, in __call__
return convert_to_fp32(self.model_forward(*args, **kwargs))
File "/workspace/pge/tools/miniconda3/envs/sdtrain/lib/python3.10/site-packages/torch/amp/autocast_mode.py", line 43, in decorate_autocast
return func(*args, **kwargs)
File "/workspace/pge/projects/sd_trainer/diffusers/src/diffusers/models/controlnet_sd3.py", line 339, in forward
hidden_states = torch.utils.checkpoint.checkpoint(
File "/workspace/pge/tools/miniconda3/envs/sdtrain/lib/python3.10/site-packages/torch/_compile.py", line 31, in inner
return disable_fn(*args, **kwargs)
File "/workspace/pge/tools/miniconda3/envs/sdtrain/lib/python3.10/site-packages/torch/_dynamo/eval_frame.py", line 600, in _fn
return fn(*args, **kwargs)
File "/workspace/pge/tools/miniconda3/envs/sdtrain/lib/python3.10/site-packages/torch/utils/checkpoint.py", line 488, in checkpoint
ret = function(*args, **kwargs)
File "/workspace/pge/projects/sd_trainer/diffusers/src/diffusers/models/controlnet_sd3.py", line 334, in custom_forward
return module(*inputs)
File "/workspace/pge/tools/miniconda3/envs/sdtrain/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1553, in _wrapped_call_impl
return self._call_impl(*args, **kwargs)
File "/workspace/pge/tools/miniconda3/envs/sdtrain/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1562, in _call_impl
return forward_call(*args, **kwargs)
File "/workspace/pge/projects/sd_trainer/diffusers/src/diffusers/models/attention.py", line 162, in forward
norm_hidden_states, gate_msa, shift_mlp, scale_mlp, gate_mlp = self.norm1(hidden_states, emb=temb)
File "/workspace/pge/tools/miniconda3/envs/sdtrain/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1553, in _wrapped_call_impl
return self._call_impl(*args, **kwargs)
File "/workspace/pge/tools/miniconda3/envs/sdtrain/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1562, in _call_impl
return forward_call(*args, **kwargs)
File "/workspace/pge/projects/sd_trainer/diffusers/src/diffusers/models/normalization.py", line 139, in forward
x = self.norm(x) * (1 + scale_msa[:, None]) + shift_msa[:, None]
File "/workspace/pge/tools/miniconda3/envs/sdtrain/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1553, in _wrapped_call_impl
return self._call_impl(*args, **kwargs)
File "/workspace/pge/tools/miniconda3/envs/sdtrain/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1562, in _call_impl
return forward_call(*args, **kwargs)
File "/workspace/pge/tools/miniconda3/envs/sdtrain/lib/python3.10/site-packages/torch/nn/modules/normalization.py", line 202, in forward
return F.layer_norm(
File "/workspace/pge/tools/miniconda3/envs/sdtrain/lib/python3.10/site-packages/torch/nn/functional.py", line 2576, in layer_norm
return torch.layer_norm(input, normalized_shape, weight, bias, eps, torch.backends.cudnn.enabled)
TypeError: layer_norm(): argument 'input' (position 1) must be Tensor, not tuple
```
### System Info
- 🤗 Diffusers version: 0.31.0.dev0
- Platform: Linux-5.15.0-105-generic-x86_64-with-glibc2.39
- Running on Google Colab?: No
- Python version: 3.10.14
- PyTorch version (GPU?): 2.4.1+cu121 (True)
- Flax version (CPU?/GPU?/TPU?): not installed (NA)
- Jax version: not installed
- JaxLib version: not installed
- Huggingface_hub version: 0.25.0
- Transformers version: 4.44.2
- Accelerate version: 0.34.2
- PEFT version: not installed
- Bitsandbytes version: 0.43.3
- Safetensors version: 0.4.5
- xFormers version: 0.0.28.post1
- Accelerator: NVIDIA A40, 46068 MiB
NVIDIA A40, 46068 MiB
- Using GPU in script?:
- Using distributed or parallel set-up in script?:
### Who can help?