Skip to content
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
16 changes: 13 additions & 3 deletions polyfactory/factories/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -73,7 +73,7 @@
handle_constrained_collection,
handle_constrained_mapping,
)
from polyfactory.value_generators.constrained_dates import handle_constrained_date
from polyfactory.value_generators.constrained_dates import handle_constrained_date, handle_constrained_datetime
from polyfactory.value_generators.constrained_numbers import (
handle_constrained_decimal,
handle_constrained_float,
Expand Down Expand Up @@ -624,15 +624,15 @@ def create_factory(
)

@classmethod
def get_constrained_field_value( # noqa: C901, PLR0911
def get_constrained_field_value( # noqa: C901, PLR0911, PLR0912
cls,
annotation: Any,
field_meta: FieldMeta,
field_build_parameters: Any | None = None,
build_context: BuildContext | None = None,
) -> Any:
constraints = cast("Constraints", field_meta.constraints)
try:
constraints = cast("Constraints", field_meta.constraints)
if is_safe_subclass(annotation, float):
return handle_constrained_float(
random=cls.__random__,
Expand Down Expand Up @@ -705,6 +705,16 @@ def get_constrained_field_value( # noqa: C901, PLR0911
build_context=build_context,
)

if is_safe_subclass(annotation, datetime):
return handle_constrained_datetime(
faker=cls.__faker__,
ge=cast("Any", constraints.get("ge")),
gt=cast("Any", constraints.get("gt")),
le=cast("Any", constraints.get("le")),
lt=cast("Any", constraints.get("lt")),
tz=cast("Any", constraints.get("tz")),
)

if is_safe_subclass(annotation, date):
return handle_constrained_date(
faker=cls.__faker__,
Expand Down
42 changes: 38 additions & 4 deletions polyfactory/value_generators/constrained_dates.py
Original file line number Diff line number Diff line change
Expand Up @@ -27,15 +27,49 @@ def handle_constrained_date(
:returns: A date instance.
"""
start_date = datetime.now(tz=tz).date() - timedelta(days=100)
if ge:
if ge is not None:
start_date = ge
elif gt:
elif gt is not None:
start_date = gt + timedelta(days=1)

end_date = datetime.now(tz=timezone.utc).date() + timedelta(days=100)
if le:
if le is not None:
end_date = le
elif lt:
elif lt is not None:
end_date = lt - timedelta(days=1)

return faker.date_between(start_date=start_date, end_date=end_date)


def handle_constrained_datetime(
faker: Faker,
ge: datetime | None = None,
gt: datetime | None = None,
le: datetime | None = None,
lt: datetime | None = None,
tz: tzinfo | None = None,
) -> datetime:
"""Generates a datetime value fulfilling the expected constraints.

:param faker: An instance of faker.
:param lt: Less than value.
:param le: Less than or equal value.
:param gt: Greater than value.
:param ge: Greater than or equal value.
:param tz: A timezone. If not provided, infers from constraint values.

:returns: A datetime instance.
"""
start_datetime = datetime.now(tz=tz) - timedelta(days=100)
if ge:
start_datetime = ge
elif gt:
start_datetime = gt + timedelta(seconds=1)

end_datetime = start_datetime + timedelta(days=30)
if le is not None:
end_datetime = le
elif lt is not None:
end_datetime = lt - timedelta(seconds=1)

return faker.date_time_between(start_date=start_datetime, end_date=end_datetime, tzinfo=tz)
10 changes: 9 additions & 1 deletion tests/constraints/test_date_constraints.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,15 @@
dates(max_value=date.today() - timedelta(days=3)),
dates(min_value=date.today()),
)
@pytest.mark.parametrize(("start", "end"), (("ge", "le"), ("gt", "lt"), ("ge", "lt"), ("gt", "le")))
@pytest.mark.parametrize(
("start", "end"),
(
("ge", "le"),
("gt", "lt"),
("ge", "lt"),
("gt", "le"),
),
)
def test_handle_constrained_date(
start: Optional[str],
end: Optional[str],
Expand Down
82 changes: 82 additions & 0 deletions tests/constraints/test_datetime_constraints.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,82 @@
"""Test datetime constraints, including Issue #734."""

import contextlib
from datetime import datetime, timedelta, timezone
from typing import Annotated, Optional

import pytest
from annotated_types import Timezone
from hypothesis import given
from hypothesis.strategies import datetimes

from pydantic import BaseModel, Field, __version__

with contextlib.suppress(ImportError):
from pydantic import BeforeValidator


from polyfactory.factories.pydantic_factory import ModelFactory


@given(
datetimes(min_value=datetime(1900, 1, 1), max_value=datetime.now() - timedelta(days=3)),
datetimes(min_value=datetime.now(), max_value=datetime(2100, 1, 1)),
)
@pytest.mark.parametrize(
("start", "end"),
(
("ge", "le"),
("gt", "lt"),
("ge", "lt"),
("gt", "le"),
),
)
def test_handle_constrained_datetime(
start: Optional[str],
end: Optional[str],
start_datetime: datetime,
end_datetime: datetime,
) -> None:
"""Test that constrained datetimes are generated correctly."""
if start_datetime == end_datetime:
return

kwargs: dict[str, datetime] = {}
if start:
kwargs[start] = start_datetime
if end:
kwargs[end] = end_datetime

class MyModel(BaseModel):
value: datetime = Field(**kwargs) # type: ignore

class MyFactory(ModelFactory[MyModel]): ...

result = MyFactory.build()

assert result.value
assert isinstance(result.value, datetime), "Should be datetime.datetime, not date"
assert result.value >= start_datetime
assert result.value <= end_datetime


@pytest.mark.skipif(__version__.startswith("1"), reason="Pydantic v2 required")
def test_annotated_datetime_with_validator_and_constraint() -> None:
def validate_datetime(value: datetime) -> datetime:
"""Validator that expects a datetime object with timezone info."""
assert isinstance(value, datetime), f"Expected datetime.datetime, got {type(value)}"
assert value.tzinfo == timezone.utc, f"Expected UTC timezone, got {value.tzinfo}"
return value

ValidatedDatetime = Annotated[datetime, BeforeValidator(validate_datetime), Timezone(tz=timezone.utc)]
minimum_datetime = datetime(2030, 1, 1, tzinfo=timezone.utc)

class MyModel(BaseModel):
dt: ValidatedDatetime = Field(gt=minimum_datetime) # pyright: ignore[reportInvalidTypeForm]

class MyModelFactory(ModelFactory[MyModel]): ...

instance = MyModelFactory.build()
assert isinstance(instance.dt, datetime), "Should be datetime.datetime"
assert instance.dt.tzinfo == timezone.utc, "Should have UTC timezone"
assert instance.dt > minimum_datetime, "Should respect gt constraint"
Loading