From 4557260dea2d82d15570363ccf1682e14737570c Mon Sep 17 00:00:00 2001 From: Alex Hlavinka Date: Tue, 25 Feb 2025 21:13:54 -0600 Subject: [PATCH 1/9] feat(fields): implement Param and CallableParam for handling unmapped parameters that can be referenced during build --- polyfactory/exceptions.py | 4 ++ polyfactory/factories/base.py | 44 +++++++++++-- polyfactory/fields.py | 121 +++++++++++++++++++++++++++++++++- 3 files changed, 164 insertions(+), 5 deletions(-) diff --git a/polyfactory/exceptions.py b/polyfactory/exceptions.py index 53f1271a..a4c22b75 100644 --- a/polyfactory/exceptions.py +++ b/polyfactory/exceptions.py @@ -16,3 +16,7 @@ class MissingBuildKwargException(FactoryException): class MissingDependencyException(FactoryException, ImportError): """Missing dependency exception - used when a dependency is not installed""" + + +class MissingParamException(FactoryException): + """Missing parameter exception - used when a required Param is not provided""" diff --git a/polyfactory/factories/base.py b/polyfactory/factories/base.py index a2eaf52d..da828786 100644 --- a/polyfactory/factories/base.py +++ b/polyfactory/factories/base.py @@ -51,9 +51,14 @@ MIN_COLLECTION_LENGTH, RANDOMIZE_COLLECTION_LENGTH, ) -from polyfactory.exceptions import ConfigurationException, MissingBuildKwargException, ParameterException +from polyfactory.exceptions import ( + ConfigurationException, + MissingBuildKwargException, + MissingParamException, + ParameterException, +) from polyfactory.field_meta import Null -from polyfactory.fields import Fixture, Ignore, PostGenerated, Require, Use +from polyfactory.fields import BaseParam, Fixture, Ignore, IsNotPassed, PostGenerated, Require, Use from polyfactory.utils.helpers import ( flatten_annotation, get_collection_type, @@ -965,6 +970,30 @@ def _check_declared_fields_exist_in_model(cls) -> None: if isinstance(field_value, (Use, PostGenerated, Ignore, Require)): raise ConfigurationException(error_message) + @classmethod + def _handle_factory_params(cls, params: dict[str, BaseParam], **kwargs: Any) -> dict[str, Any]: + """Get the factory parameters. + + :param params: A dict of field name to Param instances. + :param kwargs: Any build kwargs. + + :returns: A dict of fieldname mapped to realized Param values. + """ + + try: + return {name: param.to_value(kwargs.get(name, IsNotPassed)) for name, param in params.items()} + except MissingParamException as e: + msg = "Missing required kwargs" + raise MissingBuildKwargException(msg) from e + + @classmethod + def get_factory_params(cls) -> dict[str, BaseParam]: + """Get the factory parameters. + + :returns: A dict of field name to Param instances. + """ + return {name: item for name, item in cls.__dict__.items() if isinstance(item, BaseParam)} + @classmethod def process_kwargs(cls, **kwargs: Any) -> dict[str, Any]: """Process the given kwargs and generate values for the factory's model. @@ -980,6 +1009,9 @@ def process_kwargs(cls, **kwargs: Any) -> dict[str, Any]: result: dict[str, Any] = {**kwargs} generate_post: dict[str, PostGenerated] = {} + params = cls.get_factory_params() + result.update(cls._handle_factory_params(params, **kwargs)) + for field_meta in cls.get_model_fields(): field_build_parameters = cls.extract_field_build_parameters(field_meta=field_meta, build_args=kwargs) if cls.should_set_field_value(field_meta, **kwargs) and not cls.should_use_default_value(field_meta): @@ -1016,7 +1048,7 @@ def process_kwargs(cls, **kwargs: Any) -> dict[str, Any]: for field_name, post_generator in generate_post.items(): result[field_name] = post_generator.to_value(field_name, result) - return result + return {key: value for key, value in result.items() if key not in params} @classmethod def process_kwargs_coverage(cls, **kwargs: Any) -> abc.Iterable[dict[str, Any]]: @@ -1034,6 +1066,9 @@ def process_kwargs_coverage(cls, **kwargs: Any) -> abc.Iterable[dict[str, Any]]: result: dict[str, Any] = {**kwargs} generate_post: dict[str, PostGenerated] = {} + params = cls.get_factory_params() + result.update(cls._handle_factory_params(params, **kwargs)) + for field_meta in cls.get_model_fields(): field_build_parameters = cls.extract_field_build_parameters(field_meta=field_meta, build_args=kwargs) @@ -1069,7 +1104,8 @@ def process_kwargs_coverage(cls, **kwargs: Any) -> abc.Iterable[dict[str, Any]]: for resolved in resolve_kwargs_coverage(result): for field_name, post_generator in generate_post.items(): resolved[field_name] = post_generator.to_value(field_name, resolved) - yield resolved + + yield {key: value for key, value in resolved.items() if key not in params} @classmethod def build(cls, **kwargs: Any) -> T: diff --git a/polyfactory/fields.py b/polyfactory/fields.py index b325330d..7698b984 100644 --- a/polyfactory/fields.py +++ b/polyfactory/fields.py @@ -4,9 +4,10 @@ from typing_extensions import ParamSpec -from polyfactory.exceptions import ParameterException +from polyfactory.exceptions import MissingParamException, ParameterException T = TypeVar("T") +U = TypeVar("U") P = ParamSpec("P") @@ -114,3 +115,121 @@ def to_value(self) -> Any: msg = "fixture has not been registered using the register_factory decorator" raise ParameterException(msg) + + +class NotPassed: + """Indicates a parameter was not passed to a factory field and must be + passed at build time. + """ + + +IsNotPassed = NotPassed() + + +class BaseParam(Generic[T, U]): + """Base class for parameters. + + This class is used to pass a parameters that can be referenced by other + fields but will not be passed to the final object. + + It is generic over the type of the parameter that will be used during build + and also the method used to generate that value (e.g. as a constant or a + callable). + """ + + def to_value(self, from_build: U | NotPassed = IsNotPassed) -> T: + """Determines the value of the parameter. + + This method must be implemented in subclasses. + + :param from_build: The value passed at build time. + :returns: The value + :raises: NotImplementedError + """ + msg = "to_value must be implemented in subclasses" + raise NotImplementedError(msg) # pragma: no cover + + +class Param(Generic[T], BaseParam[T, T]): + """A constant parameter that can be used by other fields but will not be + passed to the final object. + + If a value for the parameter is not passed in the field's definition, it must + be passed at build time. Otherwise, a MissingParamException will be raised. + """ + + __slots__ = ("param",) + + def __init__(self, param: T | NotPassed = IsNotPassed) -> None: + """Designate a parameter. + + :param param: A constant or an unpassed value that can be referenced later + """ + self.param = param + + def to_value(self, from_build: T | NotPassed = IsNotPassed) -> T: + """Determines the value to use at build time + + If a value was passed to the constructor, it will be used. Otherwise, the value + passed at build time will be used. If no value was passed at build time, a + MissingParamException will be raised. + + :param args: from_build: The value passed at build time (if any). + :returns: The value + :raises: MissingParamException + """ + if self.param is IsNotPassed: + if from_build is not IsNotPassed: + return cast(T, from_build) + msg = "Param value was not passed at build time" + raise MissingParamException(msg) + return cast(T, self.param) + + +class CallableParam(Generic[T], BaseParam[T, Callable[..., T]]): + """A callable parameter that can be used by other fields but will not be + passed to the final object. + + The callable may be passed optional keyword arguments via the constructor + of this class. The callable will be invoked with the passed keyword + arguments and any positional arguments passed at build time. + + If a callable for the parameter is not passed in the field's definition, it must + be passed at build time. Otherwise, a MissingParamException will be raised. + """ + + __slots__ = ( + "kwargs", + "param", + ) + + def __init__( + self, + param: Callable[..., T] | NotPassed = IsNotPassed, + **kwargs: Any, + ) -> None: + """Designate field as a callable parameter. + + :param param: A callable that will be evaluated at build time. + :param kwargs: Any kwargs to pass to the callable. + """ + self.param = param + self.kwargs = kwargs + + def to_value(self, from_build: Callable[..., T] | NotPassed = IsNotPassed) -> T: + """Determine the value to use at build time. + + If a value was passed to the constructor, it will be used. Otherwise, the value + passed at build time will be used. If no value was passed at build time, a + MissingParamException will be raised. + + :param args: from_build: The callable passed at build time (if any). + :returns: The value + :raises: MissingParamException + """ + if self.param is IsNotPassed: + if from_build is not IsNotPassed: + return cast(Callable[..., T], from_build)(**self.kwargs) + msg = "Param value was not passed at build time" + raise MissingParamException(msg) + return cast(Callable[..., T], self.param)(**self.kwargs) From 0921a255edaa225d35a46e7918996ee9e3abc9e9 Mon Sep 17 00:00:00 2001 From: Alex Hlavinka Date: Tue, 25 Feb 2025 21:15:47 -0600 Subject: [PATCH 2/9] test(test_factory_fields): implement tests for two new unmapped parameter classes --- tests/test_factory_fields.py | 94 +++++++++++++++++++++++++- tests/test_type_coverage_generation.py | 80 +++++++++++++++++++++- 2 files changed, 172 insertions(+), 2 deletions(-) diff --git a/tests/test_factory_fields.py b/tests/test_factory_fields.py index e07aa9f6..50c93566 100644 --- a/tests/test_factory_fields.py +++ b/tests/test_factory_fields.py @@ -11,7 +11,7 @@ from polyfactory.exceptions import ConfigurationException, MissingBuildKwargException from polyfactory.factories.dataclass_factory import DataclassFactory from polyfactory.factories.pydantic_factory import ModelFactory -from polyfactory.fields import Ignore, PostGenerated, Require, Use +from polyfactory.fields import CallableParam, Ignore, Param, PostGenerated, Require, Use def test_use() -> None: @@ -83,6 +83,98 @@ class MyFactory(ModelFactory): assert MyFactory.build().name is None +def test_param__from_factory() -> None: + value: int = 3 + + class MyModel(BaseModel): + description: str + + class MyFactory(ModelFactory): + __model__ = MyModel + length = Param[int](value) + + @post_generated + @classmethod + def description(cls, length: int) -> str: + return "abcd"[:length] + + result = MyFactory.build() + assert result.description == "abc" + + +def test_param__from_kwargs() -> None: + value: int = 3 + + class MyModel(BaseModel): + description: str + + class MyFactory(ModelFactory): + __model__ = MyModel + length = Param[int]() + + @post_generated + @classmethod + def description(cls, length: int) -> str: + return "abcd"[:length] + + result = MyFactory.build(length=value) + assert result.description == "abc" + + +def test_param__from_kwargs__missing() -> None: + class MyModel(BaseModel): + description: str + + class MyFactory(ModelFactory): + __model__ = MyModel + length = Param[int]() + + @post_generated + @classmethod + def description(cls, length: int) -> str: + return "abcd"[:length] + + with pytest.raises(MissingBuildKwargException): + MyFactory.build() + + +def test_callable_param__from_factory() -> None: + class MyModel(BaseModel): + description: str + + class MyFactory(ModelFactory): + __model__ = MyModel + length = CallableParam(lambda value: value, value=3) + + @post_generated + @classmethod + def description(cls, length: int) -> str: + return "abcd"[:length] + + result = MyFactory.build() + assert result.description == "abc" + + +def test_callable_param__from_kwargs() -> None: + value1: int = 2 + value2: int = 1 + + class MyModel(BaseModel): + description: str + + class MyFactory(ModelFactory): + __model__ = MyModel + length = CallableParam[int](value1=value1, value2=value2) + + @post_generated + @classmethod + def description(cls, length: int) -> str: + return "abcd"[:length] + + result = MyFactory.build(length=lambda value1, value2: value1 + value2) + assert result.description == "abcd"[: value1 + value2] + + def test_post_generation() -> None: random_delta = timedelta(days=random.randint(0, 12), seconds=random.randint(13, 13000)) diff --git a/tests/test_type_coverage_generation.py b/tests/test_type_coverage_generation.py index 723ddc90..a67e3320 100644 --- a/tests/test_type_coverage_generation.py +++ b/tests/test_type_coverage_generation.py @@ -12,10 +12,11 @@ from pydantic import BaseModel from polyfactory.decorators import post_generated -from polyfactory.exceptions import ParameterException +from polyfactory.exceptions import MissingBuildKwargException, ParameterException from polyfactory.factories.dataclass_factory import DataclassFactory from polyfactory.factories.pydantic_factory import ModelFactory from polyfactory.factories.typed_dict_factory import TypedDictFactory +from polyfactory.fields import Param from polyfactory.utils.types import NoneType from tests.test_pydantic_factory import IS_PYDANTIC_V1 @@ -39,6 +40,83 @@ class ProfileFactory(DataclassFactory[Profile]): assert isinstance(result, Profile) +def test_coverage_param__from_factory() -> None: + value = "Demo" + + @dataclass + class Profile: + name: str + high_score: Union[int, float] + dob: date + data: Union[str, date, int, float] + + class ProfileFactory(DataclassFactory[Profile]): + __model__ = Profile + last_name = Param[str](value) + + @post_generated + @classmethod + def name(cls, last_name: str) -> str: + return f"The {last_name}" + + results = list(ProfileFactory.coverage()) + + assert len(results) == 4 + + for result in results: + assert isinstance(result, Profile) + assert result.name == f"The {value}" + + +def test_coverage_param__from_kwargs() -> None: + value = "Demo" + + @dataclass + class Profile: + name: str + high_score: Union[int, float] + dob: date + data: Union[str, date, int, float] + + class ProfileFactory(DataclassFactory[Profile]): + __model__ = Profile + last_name = Param[str]() + + @post_generated + @classmethod + def name(cls, last_name: str) -> str: + return f"The {last_name}" + + results = list(ProfileFactory.coverage(last_name=value)) + + assert len(results) == 4 + + for result in results: + assert isinstance(result, Profile) + assert result.name == f"The {value}" + + +def test_coverage_param__from_kwargs__missing() -> None: + @dataclass + class Profile: + name: str + high_score: Union[int, float] + dob: date + data: Union[str, date, int, float] + + class ProfileFactory(DataclassFactory[Profile]): + __model__ = Profile + last_name = Param[str]() + + @post_generated + @classmethod + def name(cls, last_name: str) -> str: + return f"The {last_name}" + + with pytest.raises(MissingBuildKwargException): + list(ProfileFactory.coverage()) + + def test_coverage_tuple() -> None: @dataclass class Pair: From cbbcfc3770ad52392bcba5d4f989399f1797fe47 Mon Sep 17 00:00:00 2001 From: Alex Hlavinka Date: Tue, 25 Feb 2025 22:00:13 -0600 Subject: [PATCH 3/9] docs(fields): updated docs to include definitions and examples with new Param and CallableParam field types --- docs/examples/fields/test_example_10.py | 49 +++++++++++++++++++++++ docs/examples/fields/test_example_9.py | 52 +++++++++++++++++++++++++ docs/usage/fields.rst | 30 ++++++++++++++ 3 files changed, 131 insertions(+) create mode 100644 docs/examples/fields/test_example_10.py create mode 100644 docs/examples/fields/test_example_9.py diff --git a/docs/examples/fields/test_example_10.py b/docs/examples/fields/test_example_10.py new file mode 100644 index 00000000..be042b43 --- /dev/null +++ b/docs/examples/fields/test_example_10.py @@ -0,0 +1,49 @@ +from dataclasses import dataclass + +from polyfactory.decorators import post_generated +from polyfactory.factories import DataclassFactory +from polyfactory.fields import CallableParam + + +@dataclass +class Person: + name: str + age_next_year: int + + +class PersonFactoryWithParamValueSpecifiedInFactory(DataclassFactory[Person]): + """In this factory, the next_years_age_from_calculator must be passed at build time.""" + + next_years_age_from_calculator = CallableParam[int](lambda age: age + 1, age=20) + + @post_generated + @classmethod + def age_next_year(cls, next_years_age_from_calculator: int) -> int: + return next_years_age_from_calculator + + +def test_factory__in_factory() -> None: + person = PersonFactoryWithParamValueSpecifiedInFactory.build() + + assert isinstance(person, Person) + assert not hasattr(person, "next_years_age_from_calculator") + assert person.age_next_year == 21 + + +class PersonFactoryWithParamValueSetAtBuild(DataclassFactory[Person]): + """In this factory, the next_years_age_from_calculator must be passed at build time.""" + + next_years_age_from_calculator = CallableParam[int](age=20) + + @post_generated + @classmethod + def age_next_year(cls, next_years_age_from_calculator: int) -> int: + return next_years_age_from_calculator + + +def test_factory__build_time() -> None: + person = PersonFactoryWithParamValueSpecifiedInFactory.build(next_years_age_from_calculator=lambda age: age + 1) + + assert isinstance(person, Person) + assert not hasattr(person, "next_years_age_from_calculator") + assert person.age_next_year == 21 diff --git a/docs/examples/fields/test_example_9.py b/docs/examples/fields/test_example_9.py new file mode 100644 index 00000000..fd7d7480 --- /dev/null +++ b/docs/examples/fields/test_example_9.py @@ -0,0 +1,52 @@ +from dataclasses import dataclass +from typing import List + +from polyfactory.decorators import post_generated +from polyfactory.factories import DataclassFactory +from polyfactory.fields import Param + + +@dataclass +class Pet: + name: str + sound: str + + +class PetFactoryWithParamValueSetAtBuild(DataclassFactory[Pet]): + """In this factory, the name_choices must be passed at build time.""" + + name_choices = Param[List[str]]() + + @post_generated + @classmethod + def name(cls, name_choices: List[str]) -> str: + return cls.__random__.choice(name_choices) + + +def test_factory__build_time() -> None: + names = ["Ralph", "Roxy"] + pet = PetFactoryWithParamValueSetAtBuild.build(name_choices=names) + + assert isinstance(pet, Pet) + assert not hasattr(pet, "name_choices") + assert pet.name in names + + +class PetFactoryWithParamSpecififiedInFactory(DataclassFactory[Pet]): + """In this factory, the name_choices are specified in the + factory and do not need to be passed at build time.""" + + name_choices = Param[List[str]](["Ralph", "Roxy"]) + + @post_generated + @classmethod + def name(cls, name_choices: List[str]) -> str: + return cls.__random__.choice(name_choices) + + +def test_factory__in_factory() -> None: + pet = PetFactoryWithParamSpecififiedInFactory.build() + + assert isinstance(pet, Pet) + assert not hasattr(pet, "name_choices") + assert pet.name in ["Ralph", "Roxy"] diff --git a/docs/usage/fields.rst b/docs/usage/fields.rst index 4764063e..4df063fc 100644 --- a/docs/usage/fields.rst +++ b/docs/usage/fields.rst @@ -78,6 +78,36 @@ The signature for use is: ``cb: Callable, *args, **defaults`` it can receive an callable should be: ``name: str, values: dict[str, Any], *args, **defaults``. The already generated values are mapped by name in the values dictionary. + +The ``Param`` Field +------------------- + +The :class:`Param ` class denotes a constant parameter that can be referenced by other fields at build but is not mapped to the final object. This is useful for passing values needed by other factory fields but that are not part of object being built. + +The Param type can either accept a constant value at the definition of the factory, or its value can be set at build time. + +If neither a value is provided at the definition of the factory nor at build time, an exception will be raised. + +.. literalinclude:: /examples/fields/test_example_9.py + :caption: Using the ``Param`` field + :language: python + + +The ``CallableParam`` Field +--------------------------- + +The :class:`CallableParam ` class denotes a callable parameter with a return value that may be referenced by other fields during build but is not mapped to the final object. Optional keyword arguments may be passed to the callable as part of the field definition on the factory. Any additional keyword arguments passed to the build method will also not be passed to the final object. + +The CallableParam type can either accept a callable provided at the definition of the factory, or its value can be passed at build time. The callable is executed at the beginning of build. + +If neither a value is provided at the definition of the factory nor at build time, an exception will be raised. + +The difference between a Param and a CallableParam is that the CallableParam is always executed at build time. If you need to pass an unmapped callable to the factory that should not automatically be executed at build time, use a Param. + +.. literalinclude:: /examples/fields/test_example_10.py + :caption: Using the ``CallableParam`` field + :language: python + Factories as Fields --------------------------- From 1af8f75f741899973b18814e0b927516d9b12435 Mon Sep 17 00:00:00 2001 From: Alex Hlavinka Date: Sun, 23 Mar 2025 13:32:27 -0500 Subject: [PATCH 4/9] feat(Param): consolidate Param API into a single class for simplified development and documentation --- polyfactory/factories/base.py | 27 +++++-- polyfactory/fields.py | 143 +++++++++++++--------------------- 2 files changed, 73 insertions(+), 97 deletions(-) diff --git a/polyfactory/factories/base.py b/polyfactory/factories/base.py index da828786..af692cde 100644 --- a/polyfactory/factories/base.py +++ b/polyfactory/factories/base.py @@ -8,7 +8,7 @@ from datetime import date, datetime, time, timedelta from decimal import Decimal from enum import EnumMeta -from functools import partial +from functools import cache, partial from importlib import import_module from ipaddress import ( IPv4Address, @@ -58,7 +58,7 @@ ParameterException, ) from polyfactory.field_meta import Null -from polyfactory.fields import BaseParam, Fixture, Ignore, IsNotPassed, PostGenerated, Require, Use +from polyfactory.fields import Fixture, Ignore, Param, PostGenerated, Require, Use from polyfactory.utils.helpers import ( flatten_annotation, get_collection_type, @@ -226,6 +226,7 @@ def __init_subclass__(cls, *args: Any, **kwargs: Any) -> None: # noqa: C901 raise ConfigurationException( msg, ) + cls._check_overlapping_param_names() if cls.__check_model__: cls._check_declared_fields_exist_in_model() else: @@ -971,7 +972,7 @@ def _check_declared_fields_exist_in_model(cls) -> None: raise ConfigurationException(error_message) @classmethod - def _handle_factory_params(cls, params: dict[str, BaseParam], **kwargs: Any) -> dict[str, Any]: + def _handle_factory_params(cls, params: dict[str, Param[Any]], **kwargs: Any) -> dict[str, Any]: """Get the factory parameters. :param params: A dict of field name to Param instances. @@ -981,18 +982,32 @@ def _handle_factory_params(cls, params: dict[str, BaseParam], **kwargs: Any) -> """ try: - return {name: param.to_value(kwargs.get(name, IsNotPassed)) for name, param in params.items()} + return {name: param.to_value(kwargs.get(name, Null)) for name, param in params.items()} except MissingParamException as e: msg = "Missing required kwargs" raise MissingBuildKwargException(msg) from e @classmethod - def get_factory_params(cls) -> dict[str, BaseParam]: + @cache + def get_factory_params(cls) -> dict[str, Param[Any]]: """Get the factory parameters. :returns: A dict of field name to Param instances. """ - return {name: item for name, item in cls.__dict__.items() if isinstance(item, BaseParam)} + return {name: item for name, item in cls.__dict__.items() if isinstance(item, Param)} + + @classmethod + def _check_overlapping_param_names(cls) -> None: + """Checks if there are overlapping param names with model fields. + + + :raises: ConfigurationException + """ + model_fields_names = {field_meta.name for field_meta in cls.get_model_fields()} + overlapping_params = set(cls.get_factory_params().keys()) & model_fields_names + if overlapping_params: + msg = f"Factory Params {', '.join(overlapping_params)} overlap with model fields" + raise ConfigurationException(msg) @classmethod def process_kwargs(cls, **kwargs: Any) -> dict[str, Any]: diff --git a/polyfactory/fields.py b/polyfactory/fields.py index 7698b984..055339d7 100644 --- a/polyfactory/fields.py +++ b/polyfactory/fields.py @@ -5,6 +5,7 @@ from typing_extensions import ParamSpec from polyfactory.exceptions import MissingParamException, ParameterException +from polyfactory.field_meta import Null T = TypeVar("T") U = TypeVar("U") @@ -117,40 +118,7 @@ def to_value(self) -> Any: raise ParameterException(msg) -class NotPassed: - """Indicates a parameter was not passed to a factory field and must be - passed at build time. - """ - - -IsNotPassed = NotPassed() - - -class BaseParam(Generic[T, U]): - """Base class for parameters. - - This class is used to pass a parameters that can be referenced by other - fields but will not be passed to the final object. - - It is generic over the type of the parameter that will be used during build - and also the method used to generate that value (e.g. as a constant or a - callable). - """ - - def to_value(self, from_build: U | NotPassed = IsNotPassed) -> T: - """Determines the value of the parameter. - - This method must be implemented in subclasses. - - :param from_build: The value passed at build time. - :returns: The value - :raises: NotImplementedError - """ - msg = "to_value must be implemented in subclasses" - raise NotImplementedError(msg) # pragma: no cover - - -class Param(Generic[T], BaseParam[T, T]): +class Param(Generic[T]): """A constant parameter that can be used by other fields but will not be passed to the final object. @@ -158,78 +126,71 @@ class Param(Generic[T], BaseParam[T, T]): be passed at build time. Otherwise, a MissingParamException will be raised. """ - __slots__ = ("param",) + __slots__ = ("is_callable", "kwargs", "param") - def __init__(self, param: T | NotPassed = IsNotPassed) -> None: + def __init__( + self, param: T | Callable[..., T] | type[Null] = Null, is_callable: bool = False, **kwargs: Any + ) -> None: """Designate a parameter. :param param: A constant or an unpassed value that can be referenced later """ - self.param = param + if param is not Null and is_callable and not callable(param): + msg = "If an object is passed to param, a callable must be passed when is_callable is True" + raise ParameterException(msg) + if not is_callable and kwargs: + msg = "kwargs can only be used with callable parameters" + raise ParameterException(msg) - def to_value(self, from_build: T | NotPassed = IsNotPassed) -> T: - """Determines the value to use at build time - - If a value was passed to the constructor, it will be used. Otherwise, the value - passed at build time will be used. If no value was passed at build time, a - MissingParamException will be raised. - - :param args: from_build: The value passed at build time (if any). - :returns: The value - :raises: MissingParamException - """ - if self.param is IsNotPassed: - if from_build is not IsNotPassed: - return cast(T, from_build) - msg = "Param value was not passed at build time" - raise MissingParamException(msg) - return cast(T, self.param) - - -class CallableParam(Generic[T], BaseParam[T, Callable[..., T]]): - """A callable parameter that can be used by other fields but will not be - passed to the final object. - - The callable may be passed optional keyword arguments via the constructor - of this class. The callable will be invoked with the passed keyword - arguments and any positional arguments passed at build time. - - If a callable for the parameter is not passed in the field's definition, it must - be passed at build time. Otherwise, a MissingParamException will be raised. - """ - - __slots__ = ( - "kwargs", - "param", - ) - - def __init__( - self, - param: Callable[..., T] | NotPassed = IsNotPassed, - **kwargs: Any, - ) -> None: - """Designate field as a callable parameter. - - :param param: A callable that will be evaluated at build time. - :param kwargs: Any kwargs to pass to the callable. - """ self.param = param + self.is_callable = is_callable self.kwargs = kwargs - def to_value(self, from_build: Callable[..., T] | NotPassed = IsNotPassed) -> T: - """Determine the value to use at build time. + def to_value(self, from_build: T | Callable[..., T] | type[Null] = Null, **kwargs: Any) -> T: + """Determines the value to use at build time If a value was passed to the constructor, it will be used. Otherwise, the value passed at build time will be used. If no value was passed at build time, a MissingParamException will be raised. - :param args: from_build: The callable passed at build time (if any). + :param args: from_build: The value passed at build time (if any). :returns: The value :raises: MissingParamException """ - if self.param is IsNotPassed: - if from_build is not IsNotPassed: - return cast(Callable[..., T], from_build)(**self.kwargs) - msg = "Param value was not passed at build time" + # If no param is passed at initialization, a value must be passed now + if self.param is Null: + # from_build was passed, so determine the value based on whether or + # not we're supposed to call a callable + if from_build is not Null: + return ( + cast(T, from_build) + if not self.is_callable + else cast(Callable[..., T], from_build)(**{**self.kwargs, **kwargs}) + ) + + # Otherwise, raise an exception + msg = ( + "Expected a parameter value to be passed at build time" + if not self.is_callable + else "Expected a callable to be passed at build time" + ) raise MissingParamException(msg) - return cast(Callable[..., T], self.param)(**self.kwargs) + # A param was passed at initialization + if self.is_callable: + # In this case, we are going to call the callable, but we can still + # override if are passed a callable at build + if from_build is not Null: + if callable(from_build): + return cast(Callable[..., T], from_build)(**{**self.kwargs, **kwargs}) + + # If we were passed a value at build that isn't a callable, raise + # an exception + msg = "The value passed at build time is not callable" + raise TypeError(msg) + + # Otherwise, return the value passed at initialization + return cast(Callable[..., T], self.param)(**{**self.kwargs, **kwargs}) + + # Inthis case, we are not using a callable, so return either the value + # passed at build time or initialization + return cast(T, self.param) if from_build is Null else cast(T, from_build) From 612201d4db0703d05ab8ae534ac00a8647f207e4 Mon Sep 17 00:00:00 2001 From: Alex Hlavinka Date: Sun, 23 Mar 2025 13:34:15 -0500 Subject: [PATCH 5/9] test(test_factory_fields.py) update tests for new Param type --- tests/test_factory_fields.py | 118 ++++++++++++++++++++++++++++++++--- 1 file changed, 109 insertions(+), 9 deletions(-) diff --git a/tests/test_factory_fields.py b/tests/test_factory_fields.py index 50c93566..0e09c4e7 100644 --- a/tests/test_factory_fields.py +++ b/tests/test_factory_fields.py @@ -8,10 +8,16 @@ from pydantic import BaseModel from polyfactory.decorators import post_generated -from polyfactory.exceptions import ConfigurationException, MissingBuildKwargException +from polyfactory.exceptions import ( + ConfigurationException, + MissingBuildKwargException, + MissingParamException, + ParameterException, +) from polyfactory.factories.dataclass_factory import DataclassFactory from polyfactory.factories.pydantic_factory import ModelFactory -from polyfactory.fields import CallableParam, Ignore, Param, PostGenerated, Require, Use +from polyfactory.field_meta import Null +from polyfactory.fields import Ignore, Param, PostGenerated, Require, Use def test_use() -> None: @@ -83,7 +89,85 @@ class MyFactory(ModelFactory): assert MyFactory.build().name is None -def test_param__from_factory() -> None: +@pytest.mark.parametrize( + "value,is_callable,kwargs", + [ + (None, False, {}), + (1, False, {}), + ("foo", False, {}), + (lambda value: value, True, {}), + (lambda value1, value2: value1 + value2, True, {}), + (lambda: "foo", True, {}), + (lambda: "foo", True, {"value": 3}), + ], +) +def test_param_init(value: Any, is_callable: bool, kwargs: dict[str, Any]) -> None: + param = Param(value, is_callable, **kwargs) # type: ignore + assert isinstance(param, Param) + assert param.param == value + assert param.is_callable == is_callable + assert param.kwargs == kwargs + + +@pytest.mark.parametrize( + "value,is_callable,kwargs", + [ + (None, True, {}), + (1, True, {}), + ("foo", True, {}), + (Null, False, {"value": 3}), + (1, False, {"value": 3}), + ], +) +def test_param_init_error(value: Any, is_callable: bool, kwargs: dict[str, Any]) -> None: + with pytest.raises( + ParameterException, + ): + Param(value, is_callable, **kwargs) + + +@pytest.mark.parametrize( + "initval,is_cabllable,initkwargs,buildval,buildkwargs,outcome", + [ + (None, False, {}, Null, {}, None), + (1, False, {}, 2, {}, 2), + ("foo", False, {}, Null, {}, "foo"), + (lambda value: value, True, {}, lambda value: value + 1, {"value": 3}, 4), + (lambda value1, value2: value1 + value2, True, {"value1": 2}, Null, {"value2": 1}, 3), + (lambda: "foo", True, {}, Null, {}, "foo"), + ], +) +def test_param_to_value( + initval: Any, + is_cabllable: bool, + initkwargs: dict[str, Any], + buildval: Any, + buildkwargs: dict[str, Any], + outcome: Any, +) -> None: + assert Param(initval, is_cabllable, **initkwargs).to_value(buildval, **buildkwargs) == outcome + + +@pytest.mark.parametrize( + "initval,is_cabllable,initkwargs,buildval,buildkwargs,exc", + [ + (Null, False, {}, Null, {}, MissingParamException), + (Null, True, {}, 1, {}, TypeError), + ], +) +def test_param_to_value_exception( + initval: Any, + is_cabllable: bool, + initkwargs: dict[str, Any], + buildval: Any, + buildkwargs: dict[str, Any], + exc: type[Exception], +) -> None: + with pytest.raises(exc): + Param(initval, is_cabllable, **initkwargs).to_value(buildval, **buildkwargs) + + +def test_param_from_factory() -> None: value: int = 3 class MyModel(BaseModel): @@ -102,7 +186,7 @@ def description(cls, length: int) -> str: assert result.description == "abc" -def test_param__from_kwargs() -> None: +def test_param_from_kwargs() -> None: value: int = 3 class MyModel(BaseModel): @@ -121,7 +205,7 @@ def description(cls, length: int) -> str: assert result.description == "abc" -def test_param__from_kwargs__missing() -> None: +def test_param_from_kwargs_missing() -> None: class MyModel(BaseModel): description: str @@ -138,13 +222,13 @@ def description(cls, length: int) -> str: MyFactory.build() -def test_callable_param__from_factory() -> None: +def test_callable_param_from_factory() -> None: class MyModel(BaseModel): description: str class MyFactory(ModelFactory): __model__ = MyModel - length = CallableParam(lambda value: value, value=3) + length = Param(lambda value: value, is_callable=True, value=3) @post_generated @classmethod @@ -155,7 +239,7 @@ def description(cls, length: int) -> str: assert result.description == "abc" -def test_callable_param__from_kwargs() -> None: +def test_callable_param_from_kwargs() -> None: value1: int = 2 value2: int = 1 @@ -164,7 +248,7 @@ class MyModel(BaseModel): class MyFactory(ModelFactory): __model__ = MyModel - length = CallableParam[int](value1=value1, value2=value2) + length = Param[int](is_callable=True, value1=value1, value2=value2) @post_generated @classmethod @@ -175,6 +259,22 @@ def description(cls, length: int) -> str: assert result.description == "abcd"[: value1 + value2] +def test_param_name_overlaps_model_field() -> None: + class MyModel(BaseModel): + name: str + other: int + + with pytest.raises(ConfigurationException) as exc: + + class MyFactory(ModelFactory): + __model__ = MyModel + name = Param[str]("foo") + other = 1 + + assert "name" in str(exc) + assert "other" not in str(exc) + + def test_post_generation() -> None: random_delta = timedelta(days=random.randint(0, 12), seconds=random.randint(13, 13000)) From 2d6de83a626e13271cc429597449331edf9f55ec Mon Sep 17 00:00:00 2001 From: Alex Hlavinka Date: Sun, 23 Mar 2025 13:35:05 -0500 Subject: [PATCH 6/9] docs(fields.rst) update documentation for new Param type --- docs/examples/fields/test_example_10.py | 6 +++--- docs/usage/fields.rst | 26 +++++++++---------------- 2 files changed, 12 insertions(+), 20 deletions(-) diff --git a/docs/examples/fields/test_example_10.py b/docs/examples/fields/test_example_10.py index be042b43..fe8f7fbd 100644 --- a/docs/examples/fields/test_example_10.py +++ b/docs/examples/fields/test_example_10.py @@ -2,7 +2,7 @@ from polyfactory.decorators import post_generated from polyfactory.factories import DataclassFactory -from polyfactory.fields import CallableParam +from polyfactory.fields import Param @dataclass @@ -14,7 +14,7 @@ class Person: class PersonFactoryWithParamValueSpecifiedInFactory(DataclassFactory[Person]): """In this factory, the next_years_age_from_calculator must be passed at build time.""" - next_years_age_from_calculator = CallableParam[int](lambda age: age + 1, age=20) + next_years_age_from_calculator = Param[int](lambda age: age + 1, is_callable=True, age=20) @post_generated @classmethod @@ -33,7 +33,7 @@ def test_factory__in_factory() -> None: class PersonFactoryWithParamValueSetAtBuild(DataclassFactory[Person]): """In this factory, the next_years_age_from_calculator must be passed at build time.""" - next_years_age_from_calculator = CallableParam[int](age=20) + next_years_age_from_calculator = Param[int](is_callable=True, age=20) @post_generated @classmethod diff --git a/docs/usage/fields.rst b/docs/usage/fields.rst index 4df063fc..d9bb8257 100644 --- a/docs/usage/fields.rst +++ b/docs/usage/fields.rst @@ -82,32 +82,24 @@ name in the values dictionary. The ``Param`` Field ------------------- -The :class:`Param ` class denotes a constant parameter that can be referenced by other fields at build but is not mapped to the final object. This is useful for passing values needed by other factory fields but that are not part of object being built. +The :class:`Param ` class denotes a parameter that can be referenced by other fields at build but whose value is not set on the final object. This is useful for passing values needed by other factory fields but that are not part of object being built. -The Param type can either accept a constant value at the definition of the factory, or its value can be set at build time. +A Param type can be either a constant or a callable. If a callable is used, it will be executed at the beginning of build and its return value will be used as the value for the field. Optional keyword arguments may be passed to the callable as part of the field definition on the factory. Any additional keyword arguments passed to the Param constructor will also not be mapped into the final object. -If neither a value is provided at the definition of the factory nor at build time, an exception will be raised. +The Param type allows for flexibility in that it can either accept a value at the definition of the factory, or its value can be set at build time. If a value is provided at build time, it will take precedence over the value provided at the definition of the factory (if any). + +If neither a value is provided at the definition of the factory nor at build time, an exception will be raised. Likewise, a Param cannot have the same name as any other model field. .. literalinclude:: /examples/fields/test_example_9.py - :caption: Using the ``Param`` field + :caption: Using the ``Param`` field with a constant :language: python - -The ``CallableParam`` Field ---------------------------- - -The :class:`CallableParam ` class denotes a callable parameter with a return value that may be referenced by other fields during build but is not mapped to the final object. Optional keyword arguments may be passed to the callable as part of the field definition on the factory. Any additional keyword arguments passed to the build method will also not be passed to the final object. - -The CallableParam type can either accept a callable provided at the definition of the factory, or its value can be passed at build time. The callable is executed at the beginning of build. - -If neither a value is provided at the definition of the factory nor at build time, an exception will be raised. - -The difference between a Param and a CallableParam is that the CallableParam is always executed at build time. If you need to pass an unmapped callable to the factory that should not automatically be executed at build time, use a Param. - .. literalinclude:: /examples/fields/test_example_10.py - :caption: Using the ``CallableParam`` field + :caption: Using the ``Param`` field with a callable :language: python + + Factories as Fields --------------------------- From e1f094ee3beca7bd84f8356fbcc5a96a42b60e87 Mon Sep 17 00:00:00 2001 From: Alex Hlavinka Date: Sun, 23 Mar 2025 14:02:45 -0500 Subject: [PATCH 7/9] fix(get_factory_params): fix issue caused by compatibility of cache with older versions of python --- polyfactory/factories/base.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/polyfactory/factories/base.py b/polyfactory/factories/base.py index f3fc9422..34bf70c3 100644 --- a/polyfactory/factories/base.py +++ b/polyfactory/factories/base.py @@ -8,7 +8,7 @@ from datetime import date, datetime, time, timedelta from decimal import Decimal from enum import EnumMeta -from functools import cache, partial +from functools import lru_cache, partial from importlib import import_module from ipaddress import ( IPv4Address, @@ -1043,7 +1043,7 @@ def _handle_factory_params(cls, params: dict[str, Param[Any]], **kwargs: Any) -> raise MissingBuildKwargException(msg) from e @classmethod - @cache + @lru_cache(maxsize=None) def get_factory_params(cls) -> dict[str, Param[Any]]: """Get the factory parameters. From 9035739723d846cf38a57846479e8bb107a52050 Mon Sep 17 00:00:00 2001 From: Alex Hlavinka Date: Sun, 23 Mar 2025 14:09:47 -0500 Subject: [PATCH 8/9] fix(test_factory_fields) update with typings accepted by previous versions of python --- tests/test_factory_fields.py | 14 +++++++------- 1 file changed, 7 insertions(+), 7 deletions(-) diff --git a/tests/test_factory_fields.py b/tests/test_factory_fields.py index 0e09c4e7..5f644787 100644 --- a/tests/test_factory_fields.py +++ b/tests/test_factory_fields.py @@ -1,7 +1,7 @@ import random from dataclasses import dataclass from datetime import datetime, timedelta -from typing import Any, ClassVar, List, Optional, Union +from typing import Any, ClassVar, Dict, List, Optional, Union import pytest @@ -101,7 +101,7 @@ class MyFactory(ModelFactory): (lambda: "foo", True, {"value": 3}), ], ) -def test_param_init(value: Any, is_callable: bool, kwargs: dict[str, Any]) -> None: +def test_param_init(value: Any, is_callable: bool, kwargs: Dict[str, Any]) -> None: param = Param(value, is_callable, **kwargs) # type: ignore assert isinstance(param, Param) assert param.param == value @@ -119,7 +119,7 @@ def test_param_init(value: Any, is_callable: bool, kwargs: dict[str, Any]) -> No (1, False, {"value": 3}), ], ) -def test_param_init_error(value: Any, is_callable: bool, kwargs: dict[str, Any]) -> None: +def test_param_init_error(value: Any, is_callable: bool, kwargs: Dict[str, Any]) -> None: with pytest.raises( ParameterException, ): @@ -140,9 +140,9 @@ def test_param_init_error(value: Any, is_callable: bool, kwargs: dict[str, Any]) def test_param_to_value( initval: Any, is_cabllable: bool, - initkwargs: dict[str, Any], + initkwargs: Dict[str, Any], buildval: Any, - buildkwargs: dict[str, Any], + buildkwargs: Dict[str, Any], outcome: Any, ) -> None: assert Param(initval, is_cabllable, **initkwargs).to_value(buildval, **buildkwargs) == outcome @@ -158,9 +158,9 @@ def test_param_to_value( def test_param_to_value_exception( initval: Any, is_cabllable: bool, - initkwargs: dict[str, Any], + initkwargs: Dict[str, Any], buildval: Any, - buildkwargs: dict[str, Any], + buildkwargs: Dict[str, Any], exc: type[Exception], ) -> None: with pytest.raises(exc): From 8e0bc8f6c93c904a8bedddb5d0b48b38f5fc7ee8 Mon Sep 17 00:00:00 2001 From: Alex Hlavinka Date: Sun, 23 Mar 2025 14:14:07 -0500 Subject: [PATCH 9/9] fix(test_factory_fields) update typing for compatibility with previous versions of python --- tests/test_factory_fields.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/tests/test_factory_fields.py b/tests/test_factory_fields.py index 5f644787..800e7695 100644 --- a/tests/test_factory_fields.py +++ b/tests/test_factory_fields.py @@ -1,7 +1,7 @@ import random from dataclasses import dataclass from datetime import datetime, timedelta -from typing import Any, ClassVar, Dict, List, Optional, Union +from typing import Any, ClassVar, Dict, List, Optional, Type, Union import pytest @@ -161,7 +161,7 @@ def test_param_to_value_exception( initkwargs: Dict[str, Any], buildval: Any, buildkwargs: Dict[str, Any], - exc: type[Exception], + exc: Type[Exception], ) -> None: with pytest.raises(exc): Param(initval, is_cabllable, **initkwargs).to_value(buildval, **buildkwargs)