Fix gradients synchronization for multi-GPUs training by Isotr0py · Pull Request #989 · kohya-ss/sd-scripts · GitHub

Conversation

@Isotr0py
Contributor

@Isotr0py Isotr0py commented Dec 6, 2023

  • Related issue: About GPU utilization #965

  • Fix [Bug] Gradients not synchronized  #924

  • Remove train_util.transform_if_model_is_DDP, which caused the gradient sync bug on DDP. The model is now unwrapped manually with accelerator.unwrap_model().

  • Clean up the accelerator.prepare() code in train_network.py

  • Manually sync the LoRA network gradients in train_network.py.

For train_network.py, since we apply the LoRA by replacing the forward() method instead of replacing the module itself, the DDP-wrapped LoRA network won't synchronize gradients automatically when accelerator.backward(loss) is called. Hence we need to use accelerator.reduce() to sync the gradients manually, as sketched below.
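A minimal sketch of that idea, assuming a typical accelerate training loop where network is the LoRA network and accelerator, optimizer, and loss already exist (names are illustrative, not the exact PR code):

    # After backward(), all-reduce each LoRA parameter's gradient across processes
    # so every rank applies the same update. DDP would normally do this itself,
    # but it cannot see parameters that are only used via a patched forward().
    accelerator.backward(loss)

    if accelerator.num_processes > 1:
        for param in network.parameters():
            if param.grad is not None:
                param.grad = accelerator.reduce(param.grad, reduction="mean")

    optimizer.step()
    optimizer.zero_grad(set_to_none=True)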

@Isotr0py Isotr0py marked this pull request as draft December 6, 2023 08:47
@Isotr0py Isotr0py marked this pull request as ready for review December 6, 2023 08:53
@Isotr0py
Contributor Author

Isotr0py commented Dec 6, 2023

Gradients and parameters should now be synchronized after the all-reduce.
[Screenshot: verification output, 2023-12-06]

@Isotr0py Isotr0py changed the title Fix gradients synchronization Fix gradients synchronization for multi-GPUs training Dec 6, 2023
@kohya-ss
Owner

kohya-ss commented Dec 7, 2023

Thank you for this! I haven't tested multi-GPU training, but this PR seems to be very important.

@deepdelirious

This seems to have broken sampling:

Traceback (most recent call last):
  File "/workspace/train.py", line 44, in <module>
    train(args)
  File "/workspace/sdxl_train.py", line 610, in train
    sdxl_train_util.sample_images(
  File "/workspace/library/sdxl_train_util.py", line 367, in sample_images
    return train_util.sample_images_common(SdxlStableDiffusionLongPromptWeightingPipeline, *args, **kwargs)
  File "/workspace/library/train_util.py", line 4788, in sample_images_common
    latents = pipeline(
  File "/opt/conda/lib/python3.10/site-packages/torch/utils/_contextlib.py", line 115, in decorate_context
    return func(*args, **kwargs)
  File "/workspace/library/sdxl_lpw_stable_diffusion.py", line 926, in __call__
    dtype = self.unet.dtype
  File "/opt/conda/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1695, in __getattr__
    raise AttributeError(f"'{type(self).__name__}' object has no attribute '{name}'")
AttributeError: 'DistributedDataParallel' object has no attribute 'dtype'

@kohya-ss
Owner

@deepdelirious I've updated the dev branch. I didn't test with multiple GPUs, but I think it will fix the sampling.
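The usual remedy for that traceback is to unwrap the DDP container before the sampling code reads model attributes. A hedged sketch of that idea, using illustrative names and a stand-in module (not necessarily the exact dev-branch change):

    import torch.nn as nn
    from accelerate import Accelerator

    accelerator = Accelerator()

    # Stand-in for the real SDXL UNet (illustrative only).
    unet = nn.Sequential(nn.Linear(8, 8))
    unet = accelerator.prepare(unet)  # under multi-GPU this returns a DDP wrapper

    # A DDP wrapper forwards calls to forward(), but not arbitrary attributes,
    # so self.unet.dtype in the pipeline raises AttributeError. Unwrap first:
    unet_for_sampling = accelerator.unwrap_model(unet)
    dtype = next(unet_for_sampling.parameters()).dtype  # safe way to read the dtype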

@FurkanGozukara

I don't see these in the code. Where are they?

ddp_bucket_view
ddp_gradient_as_bucket_view
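For context, and not as a claim about where this PR defines them: in an accelerate-based script, options like gradient_as_bucket_view are usually forwarded to torch DDP through a kwargs handler rather than appearing as standalone flags. A minimal sketch:

    from accelerate import Accelerator
    from accelerate.utils import DistributedDataParallelKwargs

    # Sketch only: forward gradient_as_bucket_view to
    # torch.nn.parallel.DistributedDataParallel via accelerate's kwargs handler.
    ddp_kwargs = DistributedDataParallelKwargs(gradient_as_bucket_view=True)
    accelerator = Accelerator(kwargs_handlers=[ddp_kwargs])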

nana0304 pushed a commit to nana0304/sd-scripts that referenced this pull request Jun 4, 2025
* delete DDP wrapper

* fix train_db vae and train_network

* fix train_db vae and train_network unwrap

* network grad sync

---------

Co-authored-by: Kohya S <52813779+kohya-ss@users.noreply.github.com>
