Skip to content

Commit

Permalink
refactor(prices): reorganise schema inheritance to validate price upd…
Browse files Browse the repository at this point in the history
…ates (#342)
  • Loading branch information
raphodn authored Jun 27, 2024
1 parent 6d3c84d commit 64aac06
Show file tree
Hide file tree
Showing 5 changed files with 125 additions and 104 deletions.
6 changes: 2 additions & 4 deletions app/crud.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,9 +18,9 @@
from app.schemas import (
LocationCreate,
LocationFilter,
PriceBasicUpdatableFields,
PriceCreate,
PriceFilter,
PriceUpdate,
ProductCreate,
ProductFilter,
ProductFull,
Expand Down Expand Up @@ -345,9 +345,7 @@ def delete_price(db: Session, db_price: Price) -> bool:
return True


def update_price(
db: Session, price: Price, new_values: PriceBasicUpdatableFields
) -> Price:
def update_price(db: Session, price: Price, new_values: PriceUpdate) -> Price:
new_values_cleaned = new_values.model_dump(exclude_unset=True)
for key in new_values_cleaned:
setattr(price, key, new_values_cleaned[key])
Expand Down
4 changes: 2 additions & 2 deletions app/routers/prices.py
Original file line number Diff line number Diff line change
Expand Up @@ -29,7 +29,7 @@ def get_prices(
status_code=status.HTTP_201_CREATED,
)
def create_price(
price: schemas.PriceCreateWithValidation,
price: schemas.PriceCreate,
background_tasks: BackgroundTasks,
current_user: schemas.UserCreate = Depends(get_current_user),
app_name: str | None = None,
Expand Down Expand Up @@ -77,7 +77,7 @@ def create_price(
)
def update_price(
price_id: int,
price_new_values: schemas.PriceBasicUpdatableFields,
price_new_values: schemas.PriceUpdate,
current_user: schemas.UserCreate = Depends(get_current_user),
db: Session = Depends(get_db),
) -> Price:
Expand Down
168 changes: 92 additions & 76 deletions app/schemas.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
import datetime
from typing import Optional
from copy import deepcopy
from typing import Any, Optional, Tuple, Type

from fastapi_filter.contrib.sqlalchemy import Filter
from openfoodfacts import Flavor
Expand All @@ -9,17 +10,43 @@
BaseModel,
ConfigDict,
Field,
create_model,
field_validator,
model_validator,
)
from pydantic.fields import FieldInfo

from app.enums import CurrencyEnum, LocationOSMEnum, PricePerEnum, ProofTypeEnum
from app.models import Location, Price, Product, Proof, User

# Session
# ------------------------------------------------------------------------------

def partial_model(model: Type[BaseModel]):
"""
Custom decorator to set all fields of a model as optional.
https://stackoverflow.com/a/76560886/4293684
"""

def make_field_optional(
field: FieldInfo, default: Any = None
) -> Tuple[Any, FieldInfo]:
new = deepcopy(field)
new.default = default
new.annotation = Optional[field.annotation] # type: ignore
return new.annotation, new

return create_model(
f"Partial{model.__name__}",
__base__=model,
__module__=model.__module__,
**{
field_name: make_field_optional(field_info)
for field_name, field_info in model.model_fields.items()
},
)


# Session
# ------------------------------------------------------------------------------
class SessionBase(BaseModel):
model_config = ConfigDict(from_attributes=True)

Expand Down Expand Up @@ -233,9 +260,66 @@ class Config:

# Price
# ------------------------------------------------------------------------------
class PriceCreate(BaseModel):
model_config = ConfigDict(from_attributes=True, arbitrary_types_allowed=True)
class PriceBase(BaseModel):
model_config = ConfigDict(
from_attributes=True, arbitrary_types_allowed=True, extra="forbid"
)

price: float = Field(
gt=0,
description="price of the product, without its currency, taxes included.",
examples=[1.99],
)
price_is_discounted: bool = Field(
default=False,
description="true if the price is discounted.",
examples=[True],
)
price_without_discount: float | None = Field(
default=None,
description="price of the product without discount, without its currency, taxes included. "
"If the product is not discounted, this field must be null. ",
examples=[2.99],
)
price_per: PricePerEnum | None = Field(
default=PricePerEnum.KILOGRAM,
description="""if the price is about a barcode-less product
(if `category_tag` is provided), this field must be set to `KILOGRAM`
or `UNIT` (KILOGRAM by default).
This field is set to null and ignored if `product_code` is provided.
""",
examples=["KILOGRAM", "UNIT"],
)
currency: CurrencyEnum = Field(
description="currency of the price, as a string. "
"The currency must be a valid currency code. "
"See https://en.wikipedia.org/wiki/ISO_4217 for a list of valid currency codes.",
examples=["EUR", "USD"],
)
date: datetime.date = Field(
description="date when the product was bought.", examples=["2024-01-01"]
)

@model_validator(mode="after")
def check_price_discount(self): # type: ignore
"""
Check that:
- `price_is_discounted` is true if `price_without_discount` is passed
- `price_without_discount` is greater than `price`
"""
if self.price_without_discount is not None:
if not self.price_is_discounted:
raise ValueError(
"`price_is_discounted` must be true if `price_without_discount` is filled"
)
if self.price_without_discount <= self.price:
raise ValueError(
"`price_without_discount` must be greater than `price`"
)
return self


class PriceCreate(PriceBase):
product_code: str | None = Field(
default=None,
min_length=1,
Expand Down Expand Up @@ -291,37 +375,6 @@ class PriceCreate(BaseModel):
If one of the origins is not valid, the price will be rejected.""",
examples=[["en:france"], ["en:california"]],
)
price: float = Field(
gt=0,
description="price of the product, without its currency, taxes included.",
examples=[1.99],
)
price_is_discounted: bool = Field(
default=False,
description="true if the price is discounted.",
examples=[True],
)
price_without_discount: float | None = Field(
default=None,
description="price of the product without discount, without its currency, taxes included. "
"If the product is not discounted, this field must be null. ",
examples=[2.99],
)
price_per: PricePerEnum | None = Field(
default=PricePerEnum.KILOGRAM,
description="""if the price is about a barcode-less product
(if `category_tag` is provided), this field must be set to `KILOGRAM`
or `UNIT` (KILOGRAM by default).
This field is set to null and ignored if `product_code` is provided.
""",
examples=["KILOGRAM", "UNIT"],
)
currency: CurrencyEnum = Field(
description="currency of the price, as a string. "
"The currency must be a valid currency code. "
"See https://en.wikipedia.org/wiki/ISO_4217 for a list of valid currency codes.",
examples=["EUR", "USD"],
)
location_osm_id: int = Field(
gt=0,
description="ID of the location in OpenStreetMap: the store where the product was bought.",
Expand All @@ -333,9 +386,6 @@ class PriceCreate(BaseModel):
"information about the store using the ID.",
examples=["NODE", "WAY", "RELATION"],
)
date: datetime.date = Field(
description="date when the product was bought.", examples=["2024-01-01"]
)
proof_id: int | None = Field(
default=None,
description="ID of the proof, if any. The proof is a file (receipt or price tag image) "
Expand All @@ -345,15 +395,6 @@ class PriceCreate(BaseModel):
examples=[15],
)


class PriceCreateWithValidation(PriceCreate):
"""A version of `PriceCreate` with taxonomy validations.
These validations are not done in the `PriceCreate` model because they
they are time-consuming and only necessary when creating a price from
the API.
"""

@field_validator("labels_tags")
def labels_tags_is_valid(cls, v: list[str] | None) -> list[str] | None:
if v is not None:
Expand Down Expand Up @@ -421,35 +462,10 @@ def set_price_per_to_null_if_barcode(self): # type: ignore
self.price_per = None
return self

@model_validator(mode="after")
def check_price_discount(self): # type: ignore
"""
Check that:
- `price_is_discounted` is true if `price_without_discount` is passed
- `price_without_discount` is greater than `price`
"""
if self.price_without_discount is not None:
if not self.price_is_discounted:
raise ValueError(
"`price_is_discounted` must be true if `price_without_discount` is filled"
)
if self.price_without_discount <= self.price:
raise ValueError(
"`price_without_discount` must be greater than `price`"
)
return self


class PriceBasicUpdatableFields(BaseModel):
price: float | None = None
price_is_discounted: bool | None = None
price_without_discount: float | None = None
price_per: PricePerEnum | None = None
currency: CurrencyEnum | None = None
date: datetime.date | None = None

class Config:
extra = "forbid"
@partial_model
class PriceUpdate(PriceBase):
pass


class PriceFull(PriceCreate):
Expand Down
33 changes: 20 additions & 13 deletions tests/integration/test_api.py
Original file line number Diff line number Diff line change
Expand Up @@ -628,7 +628,7 @@ def test_get_prices_filters(db_session, user_session: SessionModel, clean_prices
user_session.user,
)
crud.create_price(
db_session, PRICE_1.model_copy(update={"price": 5.10}), user_session.user
db_session, PRICE_1.model_copy(update={"price": 4.1}), user_session.user
)
crud.create_price(
db_session,
Expand All @@ -653,8 +653,8 @@ def test_get_prices_filters(db_session, user_session: SessionModel, clean_prices
response = client.get("/api/v1/prices?category_tag=en:tomatoes")
assert response.status_code == 200
assert len(response.json()["items"]) == 1
# 1 price with price > 5
response = client.get("/api/v1/prices?price__gt=5")
# 1 price with price > 4
response = client.get("/api/v1/prices?price__gt=4")
assert response.status_code == 200
assert len(response.json()["items"]) == 1
# 1 price with currency USD
Expand Down Expand Up @@ -694,7 +694,7 @@ def test_update_price(
# create price
db_price = crud.create_price(db_session, PRICE_1, user_session.user)

new_price = 5.5
new_price = 4.5
PRICE_UPDATE_PARTIAL = {"price": new_price}
# without authentication
response = client.patch(f"/api/v1/prices/{db_price.id}")
Expand Down Expand Up @@ -741,14 +741,21 @@ def test_update_price(
response.json()["price_is_discounted"] != PRICE_1.price_is_discounted
) # False
assert response.json()["price_without_discount"] is None
# with authentication and price owner but extra fields
PRICE_UPDATE_PARTIAL_WRONG = {**PRICE_UPDATE_PARTIAL, "proof_id": 1}
response = client.patch(
f"/api/v1/prices/{db_price.id}",
headers={"Authorization": f"Bearer {user_session.token}"},
json=jsonable_encoder(PRICE_UPDATE_PARTIAL_WRONG),
)
assert response.status_code == 422
# with authentication and price owner but validation error
PRICE_UPDATE_PARTIAL_WRONG_LIST = [
{**PRICE_UPDATE_PARTIAL, "proof_id": 1}, # extra field
{**PRICE_UPDATE_PARTIAL, "price": -1}, # price negative
{"price_without_discount": 3, "price_is_discounted": False}, # incoherence
{"price_without_discount": 3} # price_without_discount < price
# {**PRICE_UPDATE_PARTIAL, "price_per": "UNIT"}
]
for PRICE_UPDATE_PARTIAL_WRONG in PRICE_UPDATE_PARTIAL_WRONG_LIST:
response = client.patch(
f"/api/v1/prices/{db_price.id}",
headers={"Authorization": f"Bearer {user_session.token}"},
json=jsonable_encoder(PRICE_UPDATE_PARTIAL_WRONG),
)
assert response.status_code == 422


def test_update_price_moderator(
Expand All @@ -757,7 +764,7 @@ def test_update_price_moderator(
# create price
db_price = crud.create_price(db_session, PRICE_1, user_session.user)

new_price = 5.5
new_price = 4.5
PRICE_UPDATE_PARTIAL = {"price": new_price}

# user_1 is moderator, not owner
Expand Down
Loading

0 comments on commit 64aac06

Please sign in to comment.