Unable to extend BaseChatModel with pydantic 2 · Issue #12358 · langchain-ai/langchain · GitHub

Unable to extend BaseChatModel with pydantic 2 #12358

@nikokaoja

Description


System Info

langchain = 0.0.324
pydantic = 2.4.2
python = 3.11
platform = macos

In my project I have a class, MyChat, which inherits from BaseChatModel and further customizes/extends it. Adding fields works fine with both pydantic v1 and v2. However, adding a field validator and/or a model validator fails with pydantic v2, raising a TypeError:

TypeError: cannot pickle 'classmethod' object

Full error:

TypeError                                 Traceback (most recent call last)
.../test.ipynb Cell 1 line 1
      9 from pydantic import ConfigDict, Extra, Field, field_validator, model_validator
     11 logger = logging.getLogger(__name__)
---> 13 class MyChat(BaseChatModel):
     14     client: Any  #: :meta private:
     15     model_name: str = Field(default="gpt-35-turbo", alias="model")

File .../.venv/lib/python3.11/site-packages/pydantic/v1/main.py:221, in ModelMetaclass.__new__(mcs, name, bases, namespace, **kwargs)
    219 elif is_valid_field(var_name) and var_name not in annotations and can_be_changed:
    220     validate_field_name(bases, var_name)
--> 221     inferred = ModelField.infer(
    222         name=var_name,
    223         value=value,
    224         annotation=annotations.get(var_name, Undefined),
    225         class_validators=vg.get_validators(var_name),
    226         config=config,
    227     )
    228     if var_name in fields:
    229         if lenient_issubclass(inferred.type_, fields[var_name].type_):

File .../.venv/lib/python3.11/site-packages/pydantic/v1/fields.py:506, in ModelField.infer(cls, name, value, annotation, class_validators, config)
    503     required = False
    504 annotation = get_annotation_from_field_info(annotation, field_info, name, config.validate_assignment)
--> 506 return cls(
    507     name=name,
    508     type_=annotation,
    509     alias=field_info.alias,
    510     class_validators=class_validators,
    511     default=value,
    512     default_factory=field_info.default_factory,
    513     required=required,
    514     model_config=config,
    515     field_info=field_info,
    516 )

File .../.venv/lib/python3.11/site-packages/pydantic/v1/fields.py:436, in ModelField.__init__(self, name, type_, class_validators, model_config, default, default_factory, required, final, alias, field_info)
    434 self.shape: int = SHAPE_SINGLETON
    435 self.model_config.prepare_field(self)
--> 436 self.prepare()

File .../.venv/lib/python3.11/site-packages/pydantic/v1/fields.py:546, in ModelField.prepare(self)
    539 def prepare(self) -> None:
    540     """
    541     Prepare the field but inspecting self.default, self.type_ etc.
    542 
    543     Note: this method is **not** idempotent (because _type_analysis is not idempotent),
    544     e.g. calling it it multiple times may modify the field and configure it incorrectly.
    545     """
--> 546     self._set_default_and_type()
    547     if self.type_.__class__ is ForwardRef or self.type_.__class__ is DeferredType:
    548         # self.type_ is currently a ForwardRef and there's nothing we can do now,
    549         # user will need to call model.update_forward_refs()
    550         return
...
--> 161     rv = reductor(4)
    162 else:
    163     reductor = getattr(x, "__reduce__", None)

TypeError: cannot pickle 'classmethod' object
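The last frames shown above appear to be the standard library's copy.deepcopy falling back to the pickle protocol (`rv = reductor(4)`). A minimal, stdlib-only sketch shows that a bare `classmethod` object is enough to trigger the same message; `validator_body` is a hypothetical stand-in, not code from the report:

```python
import copy


def validator_body(cls, v):
    # Hypothetical validator function; its body is irrelevant here.
    return v


# A bare classmethod object, like the one a decorated validator ends up holding.
validator = classmethod(validator_body)

error = ""
try:
    # deepcopy finds no __deepcopy__ or copyreg entry for classmethod and falls
    # back to __reduce_ex__(4), the same call as the last frame of the traceback.
    copy.deepcopy(validator)
except TypeError as exc:
    error = str(exc)

print(error)  # on Python 3.11: cannot pickle 'classmethod' object
```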

Who can help?

No response

Information

  • The official example notebooks/scripts
  • My own modified scripts

Related Components

  • LLMs/Chat Models
  • Embedding Models
  • Prompts / Prompt Templates / Prompt Selectors
  • Output Parsers
  • Document Loaders
  • Vector Stores / Retrievers
  • Memory
  • Agents / Agent Executors
  • Tools / Toolkits
  • Chains
  • Callbacks/Tracing
  • Async

Reproduction

  1. Install pydantic v2.
  2. Install langchain.
  3. Run the following, for example in a notebook:
import logging
from typing import Any, ClassVar
from langchain.chat_models.base import BaseChatModel
from pydantic import ConfigDict, Extra, Field, field_validator, model_validator

logger = logging.getLogger(__name__)

class MyChat(BaseChatModel):
    client: Any  #: :meta private:
    model_name: str = Field(default="gpt-35-turbo", alias="model")
    temperature: float = 0.0
    model_kwargs: dict[str, Any] = Field(default_factory=dict)
    request_timeout: float | tuple[float, float] | None = None
    max_retries: int = 6
    max_tokens: int | None = 2024
    gpu: bool = False

    model_config: ClassVar[ConfigDict] = ConfigDict(
        populate_by_name=True, strict=False, extra="ignore"
    )

    @property
    def _llm_type(self) -> str:
        return "my-chat"

    @field_validator("max_tokens", mode="before")
    @classmethod
    def check_max_tokens(cls, v: int | None) -> int | None:
        """Validate max_tokens."""
        if v is not None and v < 1:
            raise ValueError("max_tokens must be greater than 0.")
        return v
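The traceback also shows why a validator gets copied at all. The class is built by pydantic/v1/main.py, so BaseChatModel in this langchain version is a pydantic v1 model, and the `elif is_valid_field(var_name) and var_name not in annotations and can_be_changed:` branch infers a field for a public, unannotated class attribute whose value is not a type v1 leaves untouched. The sketch below is a heavily simplified, hypothetical model of that rule: `SimplifiedModelMeta`, `ValidatorWrapper` and `UNTOUCHED_TYPES` are illustrative names, not pydantic's API. It assumes that pydantic v2's `field_validator` returns a wrapper object holding a classmethod. That assumption is consistent with the error message, but it is a pydantic implementation detail.

```python
import copy
import types
from typing import Any, Dict, Tuple

# Hypothetical, simplified list of attribute types that a pydantic-v1-style
# metaclass leaves alone instead of turning them into fields.
UNTOUCHED_TYPES: Tuple[type, ...] = (
    types.FunctionType, property, type, classmethod, staticmethod,
)


class ValidatorWrapper:
    """Hypothetical stand-in for the object a v2-style decorator returns:
    an ordinary object that merely holds a classmethod."""

    def __init__(self, func: Any) -> None:
        self.wrapped = classmethod(func)


class SimplifiedModelMeta(type):
    """Toy metaclass modelling only the 'unannotated attribute' branch."""

    def __new__(mcs, name: str, bases: Tuple[type, ...], namespace: Dict[str, Any]):
        annotations = namespace.get("__annotations__", {})
        inferred: Dict[str, Any] = {}
        for var_name, value in namespace.items():
            is_valid_field = not var_name.startswith("_")
            can_be_changed = not isinstance(value, UNTOUCHED_TYPES)
            if is_valid_field and var_name not in annotations and can_be_changed:
                # Treated as a field whose default is deep-copied; this is
                # where the real traceback ends up raising TypeError.
                inferred[var_name] = copy.deepcopy(value)
        namespace["inferred_fields"] = inferred
        return super().__new__(mcs, name, bases, namespace)


def check_max_tokens(cls: Any, v: Any) -> Any:
    if v is not None and v < 1:
        raise ValueError("max_tokens must be greater than 0.")
    return v


# A bare classmethod is in UNTOUCHED_TYPES, so the class builds with no fields.
class WithPlainClassmethod(metaclass=SimplifiedModelMeta):
    check = classmethod(check_max_tokens)


# A wrapper around a classmethod is not, so it is inferred as a field and
# deep-copied, which reproduces the reported error.
error = ""
try:
    class WithWrappedValidator(metaclass=SimplifiedModelMeta):
        check = ValidatorWrapper(check_max_tokens)
except TypeError as exc:
    error = str(exc)

print(WithPlainClassmethod.inferred_fields)  # {}
print(error)
```

If this reading is correct, a pydantic v1 metaclass would recognize the v1 decorators (`validator`, `root_validator` from `pydantic.v1`), whereas it treats the v2 `field_validator`/`model_validator` output as ordinary field defaults.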

Expected behavior

Defining MyChat with a pydantic v2 field validator or model validator should not raise an error.

Metadata


Assignees

No one assigned

    Labels

    bug: Related to a bug, vulnerability, unexpected error with an existing feature

    Type

    No type

    Projects

    No projects

    Milestone

    No milestone

    Relationships

    None yet

    Development

    No branches or pull requests
