diff --git a/polyfactory/factories/attrs_factory.py b/polyfactory/factories/attrs_factory.py index dfc77f06..31024088 100644 --- a/polyfactory/factories/attrs_factory.py +++ b/polyfactory/factories/attrs_factory.py @@ -4,7 +4,7 @@ from typing import TYPE_CHECKING, Generic, TypeVar from polyfactory.exceptions import MissingDependencyException -from polyfactory.factories.base import BaseFactory +from polyfactory.factories.base import BaseFactory, cache_model_fields from polyfactory.field_meta import FieldMeta, Null if TYPE_CHECKING: @@ -35,13 +35,19 @@ def is_supported_type(cls, value: Any) -> TypeGuard[type[T]]: return isclass(value) and hasattr(value, "__attrs_attrs__") @classmethod + def _init_model(cls) -> None: + """Initialize the model and resolve type annotations.""" + super()._init_model() + if hasattr(cls, "__model__"): + cls.resolve_types(cls.__model__) + + @classmethod + @cache_model_fields def get_model_fields(cls) -> list[FieldMeta]: field_metas: list[FieldMeta] = [] none_type = type(None) - cls.resolve_types(cls.__model__) fields = attrs.fields(cls.__model__) - for field in fields: if not field.init: continue diff --git a/polyfactory/factories/base.py b/polyfactory/factories/base.py index c2b71b88..569eec83 100644 --- a/polyfactory/factories/base.py +++ b/polyfactory/factories/base.py @@ -1,6 +1,7 @@ from __future__ import annotations import copy +import functools import inspect from abc import ABC, abstractmethod from collections import Counter, abc, deque @@ -97,6 +98,22 @@ F = TypeVar("F", bound="BaseFactory[Any]") +def cache_model_fields(func: Callable[[type[F]], list["FieldMeta"]]) -> Callable[[type[F]], list["FieldMeta"]]: + """Decorator to cache the results of get_model_fields() to avoid repeated introspection. + + :param func: The get_model_fields classmethod to wrap + :returns: Wrapped function with caching + """ + + @functools.wraps(func) + def wrapper(cls: type[F]) -> list["FieldMeta"]: + if "_fields_metadata" not in cls.__dict__: + cls._fields_metadata = func(cls) + return cls._fields_metadata + + return wrapper + + class BuildContext(TypedDict): seen_models: set[type] @@ -124,12 +141,13 @@ class BaseFactory(ABC, Generic[T]): """A sync persistence handler. Can be a class or a class instance.""" __async_persistence__: type[AsyncPersistenceProtocol[T]] | AsyncPersistenceProtocol[T] | None = None """An async persistence handler. Can be a class or a class instance.""" - __set_as_default_factory_for_type__ = False + + __set_as_default_factory_for_type__: ClassVar[bool] = False """ Flag dictating whether to set as the default factory for the given type. If 'True' the factory will be used instead of dynamically generating a factory for the type. """ - __is_base_factory__: bool = False + __is_base_factory__: ClassVar[bool] = False """ Flag dictating whether the factory is a 'base' factory. Base factories are registered globally as handlers for types. For example, the 'DataclassFactory', 'TypedDictFactory' and 'ModelFactory' are all base factories. diff --git a/polyfactory/factories/dataclass_factory.py b/polyfactory/factories/dataclass_factory.py index fd7e7f08..df37c238 100644 --- a/polyfactory/factories/dataclass_factory.py +++ b/polyfactory/factories/dataclass_factory.py @@ -5,7 +5,7 @@ from typing_extensions import TypeGuard -from polyfactory.factories.base import BaseFactory, T +from polyfactory.factories.base import BaseFactory, T, cache_model_fields from polyfactory.field_meta import FieldMeta, Null @@ -24,6 +24,7 @@ def is_supported_type(cls, value: Any) -> TypeGuard[type[T]]: return bool(is_dataclass(value)) @classmethod + @cache_model_fields def get_model_fields(cls) -> list["FieldMeta"]: """Retrieve a list of fields from the factory's model. diff --git a/polyfactory/factories/msgspec_factory.py b/polyfactory/factories/msgspec_factory.py index fbbee45b..afecb30e 100644 --- a/polyfactory/factories/msgspec_factory.py +++ b/polyfactory/factories/msgspec_factory.py @@ -4,7 +4,7 @@ from typing import TYPE_CHECKING, Any, Callable, Generic, TypeVar, get_type_hints from polyfactory.exceptions import MissingDependencyException -from polyfactory.factories.base import BaseFactory +from polyfactory.factories.base import BaseFactory, cache_model_fields from polyfactory.field_meta import FieldMeta, Null from polyfactory.value_generators.constrained_numbers import handle_constrained_int from polyfactory.value_generators.primitives import create_random_bytes @@ -46,6 +46,7 @@ def is_supported_type(cls, value: Any) -> TypeGuard[type[T]]: return isclass(value) and hasattr(value, "__struct_fields__") @classmethod + @cache_model_fields def get_model_fields(cls) -> list[FieldMeta]: fields_meta: list[FieldMeta] = [] diff --git a/polyfactory/factories/pydantic_factory.py b/polyfactory/factories/pydantic_factory.py index acedc35a..f60d5b35 100644 --- a/polyfactory/factories/pydantic_factory.py +++ b/polyfactory/factories/pydantic_factory.py @@ -12,7 +12,7 @@ from typing_extensions import Literal, get_args from polyfactory.exceptions import MissingDependencyException -from polyfactory.factories.base import BaseFactory, BuildContext +from polyfactory.factories.base import BaseFactory, BuildContext, cache_model_fields from polyfactory.factories.base import BuildContext as BaseBuildContext from polyfactory.field_meta import Constraints, FieldMeta, Null from polyfactory.utils.helpers import unwrap_new_type, unwrap_optional @@ -411,6 +411,7 @@ def is_supported_type(cls, value: Any) -> TypeGuard[type[T]]: return _is_pydantic_v1_model(value) or _is_pydantic_v2_model(value) @classmethod + @cache_model_fields def get_model_fields(cls) -> list["FieldMeta"]: """Retrieve a list of fields from the factory's model. @@ -418,28 +419,25 @@ def get_model_fields(cls) -> list["FieldMeta"]: :returns: A list of field MetaData instances. """ - if "_fields_metadata" not in cls.__dict__: - if _is_pydantic_v1_model(cls.__model__): - cls._fields_metadata = [ - PydanticFieldMeta.from_model_field( - field, - use_alias=not cls.__model__.__config__.allow_population_by_field_name, # type: ignore[attr-defined] - ) - for field in cls.__model__.__fields__.values() - ] - else: - use_alias = cls.__model__.model_config.get("validate_by_name", False) or cls.__model__.model_config.get( - "populate_by_name", False + if _is_pydantic_v1_model(cls.__model__): + return [ + PydanticFieldMeta.from_model_field( + field, + use_alias=not cls.__model__.__config__.allow_population_by_field_name, # type: ignore[attr-defined] ) - cls._fields_metadata = [ - PydanticFieldMeta.from_field_info( - field_info=field_info, - field_name=field_name, - use_alias=not use_alias, - ) - for field_name, field_info in cls.__model__.model_fields.items() # pyright: ignore[reportGeneralTypeIssues] - ] - return cls._fields_metadata + for field in cls.__model__.__fields__.values() + ] + use_alias = cls.__model__.model_config.get("validate_by_name", False) or cls.__model__.model_config.get( + "populate_by_name", False + ) + return [ + PydanticFieldMeta.from_field_info( + field_info=field_info, + field_name=field_name, + use_alias=not use_alias, + ) + for field_name, field_info in cls.__model__.model_fields.items() # pyright: ignore[reportGeneralTypeIssues] + ] @classmethod def get_constrained_field_value( diff --git a/polyfactory/factories/sqlalchemy_factory.py b/polyfactory/factories/sqlalchemy_factory.py index f076c582..aac10778 100644 --- a/polyfactory/factories/sqlalchemy_factory.py +++ b/polyfactory/factories/sqlalchemy_factory.py @@ -18,6 +18,8 @@ from polyfactory.exceptions import ConfigurationException, MissingDependencyException, ParameterException from polyfactory.factories.base import BaseFactory +from polyfactory.exceptions import MissingDependencyException, ParameterException +from polyfactory.factories.base import BaseFactory, cache_model_fields from polyfactory.field_meta import Constraints, FieldMeta from polyfactory.persistence import AsyncPersistenceProtocol, SyncPersistenceProtocol from polyfactory.utils.types import Frozendict @@ -241,6 +243,7 @@ def get_type_from_collection_class( return annotation @classmethod + @cache_model_fields def get_model_fields(cls) -> list[FieldMeta]: fields_meta: list[FieldMeta] = [] diff --git a/polyfactory/factories/typed_dict_factory.py b/polyfactory/factories/typed_dict_factory.py index cecf47c8..940945f0 100644 --- a/polyfactory/factories/typed_dict_factory.py +++ b/polyfactory/factories/typed_dict_factory.py @@ -11,7 +11,7 @@ is_typeddict, ) -from polyfactory.factories.base import BaseFactory +from polyfactory.factories.base import BaseFactory, cache_model_fields from polyfactory.field_meta import FieldMeta, Null TypedDictT = TypeVar("TypedDictT", bound=_TypedDictMeta) @@ -32,6 +32,7 @@ def is_supported_type(cls, value: Any) -> TypeGuard[type[TypedDictT]]: return is_typeddict(value) @classmethod + @cache_model_fields def get_model_fields(cls) -> list["FieldMeta"]: """Retrieve a list of fields from the factory's model. diff --git a/polyfactory/field_meta.py b/polyfactory/field_meta.py index 74081c22..aca3a1e9 100644 --- a/polyfactory/field_meta.py +++ b/polyfactory/field_meta.py @@ -22,7 +22,6 @@ import datetime from collections.abc import Sequence from decimal import Decimal - from random import Random from re import Pattern from typing_extensions import NotRequired, Self @@ -68,10 +67,9 @@ class Constraints(TypedDict): class FieldMeta: """Factory field metadata container. This class is used to store the data about a field of a factory's model.""" - __slots__ = ("__dict__", "annotation", "children", "constraints", "default", "name", "random") + __slots__ = ("__dict__", "annotation", "children", "constraints", "default", "name") annotation: Any - random: Random children: list[FieldMeta] | None default: Any name: str