Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

refactor(prices): reorganise schema inheritance to validate price updates #342

Merged
merged 3 commits into from
Jun 27, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
6 changes: 2 additions & 4 deletions app/crud.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,9 +18,9 @@
from app.schemas import (
LocationCreate,
LocationFilter,
PriceBasicUpdatableFields,
PriceCreate,
PriceFilter,
PriceUpdate,
ProductCreate,
ProductFilter,
ProductFull,
Expand Down Expand Up @@ -345,9 +345,7 @@ def delete_price(db: Session, db_price: Price) -> bool:
return True


def update_price(
db: Session, price: Price, new_values: PriceBasicUpdatableFields
) -> Price:
def update_price(db: Session, price: Price, new_values: PriceUpdate) -> Price:
new_values_cleaned = new_values.model_dump(exclude_unset=True)
for key in new_values_cleaned:
setattr(price, key, new_values_cleaned[key])
Expand Down
4 changes: 2 additions & 2 deletions app/routers/prices.py
Original file line number Diff line number Diff line change
Expand Up @@ -29,7 +29,7 @@ def get_prices(
status_code=status.HTTP_201_CREATED,
)
def create_price(
price: schemas.PriceCreateWithValidation,
price: schemas.PriceCreate,
background_tasks: BackgroundTasks,
current_user: schemas.UserCreate = Depends(get_current_user),
app_name: str | None = None,
Expand Down Expand Up @@ -77,7 +77,7 @@ def create_price(
)
def update_price(
price_id: int,
price_new_values: schemas.PriceBasicUpdatableFields,
price_new_values: schemas.PriceUpdate,
current_user: schemas.UserCreate = Depends(get_current_user),
db: Session = Depends(get_db),
) -> Price:
Expand Down
168 changes: 92 additions & 76 deletions app/schemas.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
import datetime
from typing import Optional
from copy import deepcopy
from typing import Any, Optional, Tuple, Type

from fastapi_filter.contrib.sqlalchemy import Filter
from openfoodfacts import Flavor
Expand All @@ -9,17 +10,43 @@
BaseModel,
ConfigDict,
Field,
create_model,
field_validator,
model_validator,
)
from pydantic.fields import FieldInfo

from app.enums import CurrencyEnum, LocationOSMEnum, PricePerEnum, ProofTypeEnum
from app.models import Location, Price, Product, Proof, User

# Session
# ------------------------------------------------------------------------------

def partial_model(model: Type[BaseModel]):
"""
Custom decorator to set all fields of a model as optional.
https://stackoverflow.com/a/76560886/4293684
"""

def make_field_optional(
field: FieldInfo, default: Any = None
) -> Tuple[Any, FieldInfo]:
new = deepcopy(field)
new.default = default
new.annotation = Optional[field.annotation] # type: ignore
return new.annotation, new

return create_model(
f"Partial{model.__name__}",
__base__=model,
__module__=model.__module__,
**{
field_name: make_field_optional(field_info)
for field_name, field_info in model.model_fields.items()
},
)


# Session
# ------------------------------------------------------------------------------
class SessionBase(BaseModel):
model_config = ConfigDict(from_attributes=True)

Expand Down Expand Up @@ -233,9 +260,66 @@ class Config:

# Price
# ------------------------------------------------------------------------------
class PriceCreate(BaseModel):
model_config = ConfigDict(from_attributes=True, arbitrary_types_allowed=True)
class PriceBase(BaseModel):
model_config = ConfigDict(
from_attributes=True, arbitrary_types_allowed=True, extra="forbid"
)

price: float = Field(
gt=0,
description="price of the product, without its currency, taxes included.",
examples=[1.99],
)
price_is_discounted: bool = Field(
default=False,
description="true if the price is discounted.",
examples=[True],
)
price_without_discount: float | None = Field(
default=None,
description="price of the product without discount, without its currency, taxes included. "
"If the product is not discounted, this field must be null. ",
examples=[2.99],
)
price_per: PricePerEnum | None = Field(
default=PricePerEnum.KILOGRAM,
description="""if the price is about a barcode-less product
(if `category_tag` is provided), this field must be set to `KILOGRAM`
or `UNIT` (KILOGRAM by default).
This field is set to null and ignored if `product_code` is provided.
""",
examples=["KILOGRAM", "UNIT"],
)
currency: CurrencyEnum = Field(
description="currency of the price, as a string. "
"The currency must be a valid currency code. "
"See https://en.wikipedia.org/wiki/ISO_4217 for a list of valid currency codes.",
examples=["EUR", "USD"],
)
date: datetime.date = Field(
description="date when the product was bought.", examples=["2024-01-01"]
)

@model_validator(mode="after")
def check_price_discount(self): # type: ignore
"""
Check that:
- `price_is_discounted` is true if `price_without_discount` is passed
- `price_without_discount` is greater than `price`
"""
if self.price_without_discount is not None:
if not self.price_is_discounted:
raise ValueError(
"`price_is_discounted` must be true if `price_without_discount` is filled"
)
if self.price_without_discount <= self.price:
raise ValueError(
"`price_without_discount` must be greater than `price`"
)
return self


class PriceCreate(PriceBase):
product_code: str | None = Field(
default=None,
min_length=1,
Expand Down Expand Up @@ -291,37 +375,6 @@ class PriceCreate(BaseModel):
If one of the origins is not valid, the price will be rejected.""",
examples=[["en:france"], ["en:california"]],
)
price: float = Field(
gt=0,
description="price of the product, without its currency, taxes included.",
examples=[1.99],
)
price_is_discounted: bool = Field(
default=False,
description="true if the price is discounted.",
examples=[True],
)
price_without_discount: float | None = Field(
default=None,
description="price of the product without discount, without its currency, taxes included. "
"If the product is not discounted, this field must be null. ",
examples=[2.99],
)
price_per: PricePerEnum | None = Field(
default=PricePerEnum.KILOGRAM,
description="""if the price is about a barcode-less product
(if `category_tag` is provided), this field must be set to `KILOGRAM`
or `UNIT` (KILOGRAM by default).
This field is set to null and ignored if `product_code` is provided.
""",
examples=["KILOGRAM", "UNIT"],
)
currency: CurrencyEnum = Field(
description="currency of the price, as a string. "
"The currency must be a valid currency code. "
"See https://en.wikipedia.org/wiki/ISO_4217 for a list of valid currency codes.",
examples=["EUR", "USD"],
)
location_osm_id: int = Field(
gt=0,
description="ID of the location in OpenStreetMap: the store where the product was bought.",
Expand All @@ -333,9 +386,6 @@ class PriceCreate(BaseModel):
"information about the store using the ID.",
examples=["NODE", "WAY", "RELATION"],
)
date: datetime.date = Field(
description="date when the product was bought.", examples=["2024-01-01"]
)
proof_id: int | None = Field(
default=None,
description="ID of the proof, if any. The proof is a file (receipt or price tag image) "
Expand All @@ -345,15 +395,6 @@ class PriceCreate(BaseModel):
examples=[15],
)


class PriceCreateWithValidation(PriceCreate):
"""A version of `PriceCreate` with taxonomy validations.

These validations are not done in the `PriceCreate` model because they
they are time-consuming and only necessary when creating a price from
the API.
"""

@field_validator("labels_tags")
def labels_tags_is_valid(cls, v: list[str] | None) -> list[str] | None:
if v is not None:
Expand Down Expand Up @@ -421,35 +462,10 @@ def set_price_per_to_null_if_barcode(self): # type: ignore
self.price_per = None
return self

@model_validator(mode="after")
def check_price_discount(self): # type: ignore
"""
Check that:
- `price_is_discounted` is true if `price_without_discount` is passed
- `price_without_discount` is greater than `price`
"""
if self.price_without_discount is not None:
if not self.price_is_discounted:
raise ValueError(
"`price_is_discounted` must be true if `price_without_discount` is filled"
)
if self.price_without_discount <= self.price:
raise ValueError(
"`price_without_discount` must be greater than `price`"
)
return self


class PriceBasicUpdatableFields(BaseModel):
price: float | None = None
price_is_discounted: bool | None = None
price_without_discount: float | None = None
price_per: PricePerEnum | None = None
currency: CurrencyEnum | None = None
date: datetime.date | None = None

class Config:
extra = "forbid"
@partial_model
class PriceUpdate(PriceBase):
pass


class PriceFull(PriceCreate):
Expand Down
33 changes: 20 additions & 13 deletions tests/integration/test_api.py
Original file line number Diff line number Diff line change
Expand Up @@ -628,7 +628,7 @@ def test_get_prices_filters(db_session, user_session: SessionModel, clean_prices
user_session.user,
)
crud.create_price(
db_session, PRICE_1.model_copy(update={"price": 5.10}), user_session.user
db_session, PRICE_1.model_copy(update={"price": 4.1}), user_session.user
)
crud.create_price(
db_session,
Expand All @@ -653,8 +653,8 @@ def test_get_prices_filters(db_session, user_session: SessionModel, clean_prices
response = client.get("/api/v1/prices?category_tag=en:tomatoes")
assert response.status_code == 200
assert len(response.json()["items"]) == 1
# 1 price with price > 5
response = client.get("/api/v1/prices?price__gt=5")
# 1 price with price > 4
response = client.get("/api/v1/prices?price__gt=4")
assert response.status_code == 200
assert len(response.json()["items"]) == 1
# 1 price with currency USD
Expand Down Expand Up @@ -694,7 +694,7 @@ def test_update_price(
# create price
db_price = crud.create_price(db_session, PRICE_1, user_session.user)

new_price = 5.5
new_price = 4.5
PRICE_UPDATE_PARTIAL = {"price": new_price}
# without authentication
response = client.patch(f"/api/v1/prices/{db_price.id}")
Expand Down Expand Up @@ -741,14 +741,21 @@ def test_update_price(
response.json()["price_is_discounted"] != PRICE_1.price_is_discounted
) # False
assert response.json()["price_without_discount"] is None
# with authentication and price owner but extra fields
PRICE_UPDATE_PARTIAL_WRONG = {**PRICE_UPDATE_PARTIAL, "proof_id": 1}
response = client.patch(
f"/api/v1/prices/{db_price.id}",
headers={"Authorization": f"Bearer {user_session.token}"},
json=jsonable_encoder(PRICE_UPDATE_PARTIAL_WRONG),
)
assert response.status_code == 422
# with authentication and price owner but validation error
PRICE_UPDATE_PARTIAL_WRONG_LIST = [
{**PRICE_UPDATE_PARTIAL, "proof_id": 1}, # extra field
{**PRICE_UPDATE_PARTIAL, "price": -1}, # price negative
{"price_without_discount": 3, "price_is_discounted": False}, # incoherence
{"price_without_discount": 3} # price_without_discount < price
# {**PRICE_UPDATE_PARTIAL, "price_per": "UNIT"}
]
for PRICE_UPDATE_PARTIAL_WRONG in PRICE_UPDATE_PARTIAL_WRONG_LIST:
response = client.patch(
f"/api/v1/prices/{db_price.id}",
headers={"Authorization": f"Bearer {user_session.token}"},
json=jsonable_encoder(PRICE_UPDATE_PARTIAL_WRONG),
)
assert response.status_code == 422


def test_update_price_moderator(
Expand All @@ -757,7 +764,7 @@ def test_update_price_moderator(
# create price
db_price = crud.create_price(db_session, PRICE_1, user_session.user)

new_price = 5.5
new_price = 4.5
PRICE_UPDATE_PARTIAL = {"price": new_price}

# user_1 is moderator, not owner
Expand Down
Loading
Loading