Fix PyTorch 2.3.1 compatibility: add version guard for torch.library.custom_op by Aishwarya0811 · Pull Request #12206 · huggingface/diffusers · GitHub

Conversation

Aishwarya0811
Contributor

@Aishwarya0811 Aishwarya0811 commented Aug 21, 2025

What does this PR do?

Fixes #12195 by adding version guards for torch.library.custom_op and torch.library.register_fake, which are not available in PyTorch 2.3.1.

Problem

  • torch.library.custom_op and torch.library.register_fake were introduced in PyTorch 2.4
  • Users with PyTorch 2.3.1 get AttributeError: module 'torch.library' has no attribute 'custom_op'
  • As a result, even from diffusers import AutoencoderKL fails at import time (see the reproduction below)
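A minimal reproduction, assuming a PyTorch 2.3.1 environment with diffusers 0.35.0 installed:

# python -m pip install torch==2.3.1 diffusers==0.35.0
from diffusers import AutoencoderKL
# AttributeError: module 'torch.library' has no attribute 'custom_op'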

Solution

  • Added hasattr() checks before using these torch.library functions (sketched below)
  • Functions are only registered when available in the PyTorch version
  • Maintains backward compatibility with PyTorch 2.3.1 without breaking newer versions
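A minimal sketch of the guarded registration described above; the op name (mylib::scaled_copy) and the function bodies are illustrative placeholders, not the actual diffusers ops:

import torch

if hasattr(torch.library, "custom_op") and hasattr(torch.library, "register_fake"):
    # PyTorch >= 2.4: register the real custom op plus its fake (meta) kernel.
    @torch.library.custom_op("mylib::scaled_copy", mutates_args=())
    def scaled_copy(x: torch.Tensor, scale: float) -> torch.Tensor:
        return x * scale

    @torch.library.register_fake("mylib::scaled_copy")
    def _(x, scale):
        return torch.empty_like(x)

# On PyTorch 2.3.x the block above is skipped entirely, so importing the
# module no longer raises AttributeError.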

Testing

  • ✅ Tested with PyTorch 2.3.1 - import now works without errors
  • ✅ Verified AutoencoderKL imports and functions correctly
  • ✅ No regressions expected with newer PyTorch versions

Fixes #12195

Who can review?

@sayakpaul @yiyixuxu - This is a PyTorch compatibility fix for core library functionality.

Fix PyTorch 2.3.1 compatibility: add version guard for torch.library.custom_op

- Add hasattr() check for torch.library.custom_op and register_fake
- These functions were added in PyTorch 2.4, causing import failures in 2.3.1
- Both decorators and functions are now properly guarded with version checks
- Maintains backward compatibility while preserving functionality

Fixes huggingface#12195
@sayakpaul sayakpaul requested a review from a-r-r-o-w August 21, 2025 09:54
Contributor

@a-r-r-o-w a-r-r-o-w left a comment


Hi, thanks for the fix! I think we can improve a bit here by using the approach followed in #11941. Specifically, the following lines:

if torch.__version__ >= "2.4.0":
    _custom_op = torch.library.custom_op
    _register_fake = torch.library.register_fake
else:

    def _custom_op_no_op(name, fn=None, /, *, mutates_args, device_types=None, schema=None):
        def wrap(func):
            return func

        return wrap if fn is None else fn

    def _register_fake_no_op(op, fn=None, /, *, lib=None, _stacklevel=1):
        def wrap(func):
            return func

        return wrap if fn is None else fn

    _custom_op = _custom_op_no_op
    _register_fake = _register_fake_no_op
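With these aliases, registrations can be written once and behave consistently across both version ranges. A hypothetical op for illustration (mylib::noop is not one of the actual diffusers ops):

@_custom_op("mylib::noop", mutates_args=())
def _noop(x: torch.Tensor) -> torch.Tensor:
    return x.clone()

@_register_fake("mylib::noop")
def _(x):
    return torch.empty_like(x)

On PyTorch >= 2.4 this registers a real custom op with a fake (meta) kernel; on 2.3.x the no-op decorators return the functions unchanged, so importing the module still succeeds.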

Use dummy decorators approach for PyTorch version compatibility

- Replace hasattr check with version string comparison
- Add no-op decorator functions for PyTorch < 2.4.0
- Follows pattern from huggingface#11941 as suggested by reviewer
- Maintains cleaner code structure without indentation changes
@Aishwarya0811
Contributor Author

Hi @a-r-r-o-w
Thanks for the feedback! I've updated the PR to use the dummy decorator approach as suggested, following the pattern from #11941. The code is much cleaner now without the indentation changes.

Tested with PyTorch 2.3.1 and the import works correctly.

Contributor

@a-r-r-o-w a-r-r-o-w left a comment


Thanks for the updates! Just one more comment

Aishwarya0811 and others added 6 commits August 23, 2025 00:24
Update all the decorator usages

Co-authored-by: Aryan <contact.aryanvs@gmail.com>
@Aishwarya0811
Contributor Author

Hi @a-r-r-o-w I've addressed all your feedback - moved the version check to the top of the file and used private naming with underscores.

Contributor

@a-r-r-o-w a-r-r-o-w left a comment


Thanks for looking into it @Aishwarya0811!

@a-r-r-o-w
Contributor

@bot /style

@github-actions
Contributor

github-actions bot commented Aug 23, 2025

Style bot fixed some files and pushed the changes.

@HuggingFaceDocBuilderDev

The docs for this PR live here. All of your documentation changes will be reflected on that endpoint. The docs are available until 30 days after the last update.

@a-r-r-o-w a-r-r-o-w merged commit 9a7ae77 into huggingface:main Aug 23, 2025
11 checks passed
sayakpaul pushed a commit that referenced this pull request Oct 15, 2025
Fix PyTorch 2.3.1 compatibility: add version guard for torch.library.custom_op (#12206)

* Fix PyTorch 2.3.1 compatibility: add version guard for torch.library.custom_op

- Add hasattr() check for torch.library.custom_op and register_fake
- These functions were added in PyTorch 2.4, causing import failures in 2.3.1
- Both decorators and functions are now properly guarded with version checks
- Maintains backward compatibility while preserving functionality

Fixes #12195

* Use dummy decorators approach for PyTorch version compatibility

- Replace hasattr check with version string comparison
- Add no-op decorator functions for PyTorch < 2.4.0
- Follows pattern from #11941 as suggested by reviewer
- Maintains cleaner code structure without indentation changes

* Update src/diffusers/models/attention_dispatch.py

Update all the decorator usages

Co-authored-by: Aryan <contact.aryanvs@gmail.com>

* Update src/diffusers/models/attention_dispatch.py

Co-authored-by: Aryan <contact.aryanvs@gmail.com>

* Update src/diffusers/models/attention_dispatch.py

Co-authored-by: Aryan <contact.aryanvs@gmail.com>

* Update src/diffusers/models/attention_dispatch.py

Co-authored-by: Aryan <contact.aryanvs@gmail.com>

* Move version check to top of file and use private naming as requested

* Apply style fixes

---------

Co-authored-by: Aryan <contact.aryanvs@gmail.com>
Co-authored-by: Aryan <aryan@huggingface.co>
Co-authored-by: github-actions[bot] <github-actions[bot]@users.noreply.github.com>


Development

Successfully merging this pull request may close these issues.

Version 0.35.0 AutoencoderKL failed with PyTorch 2.3.1: AttributeError: module 'torch.library' has no attribute 'custom_op'
