diff --git a/app/crud.py b/app/crud.py index f3b7fa88..30de5965 100644 --- a/app/crud.py +++ b/app/crud.py @@ -18,9 +18,9 @@ from app.schemas import ( LocationCreate, LocationFilter, - PriceBasicUpdatableFields, PriceCreate, PriceFilter, + PriceUpdate, ProductCreate, ProductFilter, ProductFull, @@ -345,9 +345,7 @@ def delete_price(db: Session, db_price: Price) -> bool: return True -def update_price( - db: Session, price: Price, new_values: PriceBasicUpdatableFields -) -> Price: +def update_price(db: Session, price: Price, new_values: PriceUpdate) -> Price: new_values_cleaned = new_values.model_dump(exclude_unset=True) for key in new_values_cleaned: setattr(price, key, new_values_cleaned[key]) diff --git a/app/routers/prices.py b/app/routers/prices.py index 589ebe7b..23a23881 100644 --- a/app/routers/prices.py +++ b/app/routers/prices.py @@ -29,7 +29,7 @@ def get_prices( status_code=status.HTTP_201_CREATED, ) def create_price( - price: schemas.PriceCreateWithValidation, + price: schemas.PriceCreate, background_tasks: BackgroundTasks, current_user: schemas.UserCreate = Depends(get_current_user), app_name: str | None = None, @@ -77,7 +77,7 @@ def create_price( ) def update_price( price_id: int, - price_new_values: schemas.PriceBasicUpdatableFields, + price_new_values: schemas.PriceUpdate, current_user: schemas.UserCreate = Depends(get_current_user), db: Session = Depends(get_db), ) -> Price: diff --git a/app/schemas.py b/app/schemas.py index 6fd62b8b..0281960f 100644 --- a/app/schemas.py +++ b/app/schemas.py @@ -1,5 +1,6 @@ import datetime -from typing import Optional +from copy import deepcopy +from typing import Any, Optional, Tuple, Type from fastapi_filter.contrib.sqlalchemy import Filter from openfoodfacts import Flavor @@ -9,17 +10,43 @@ BaseModel, ConfigDict, Field, + create_model, field_validator, model_validator, ) +from pydantic.fields import FieldInfo from app.enums import CurrencyEnum, LocationOSMEnum, PricePerEnum, ProofTypeEnum from app.models import Location, Price, Product, Proof, User -# Session -# ------------------------------------------------------------------------------ + +def partial_model(model: Type[BaseModel]): + """ + Custom decorator to set all fields of a model as optional. + https://stackoverflow.com/a/76560886/4293684 + """ + + def make_field_optional( + field: FieldInfo, default: Any = None + ) -> Tuple[Any, FieldInfo]: + new = deepcopy(field) + new.default = default + new.annotation = Optional[field.annotation] # type: ignore + return new.annotation, new + + return create_model( + f"Partial{model.__name__}", + __base__=model, + __module__=model.__module__, + **{ + field_name: make_field_optional(field_info) + for field_name, field_info in model.model_fields.items() + }, + ) +# Session +# ------------------------------------------------------------------------------ class SessionBase(BaseModel): model_config = ConfigDict(from_attributes=True) @@ -233,9 +260,66 @@ class Config: # Price # ------------------------------------------------------------------------------ -class PriceCreate(BaseModel): - model_config = ConfigDict(from_attributes=True, arbitrary_types_allowed=True) +class PriceBase(BaseModel): + model_config = ConfigDict( + from_attributes=True, arbitrary_types_allowed=True, extra="forbid" + ) + + price: float = Field( + gt=0, + description="price of the product, without its currency, taxes included.", + examples=[1.99], + ) + price_is_discounted: bool = Field( + default=False, + description="true if the price is discounted.", + examples=[True], + ) + price_without_discount: float | None = Field( + default=None, + description="price of the product without discount, without its currency, taxes included. " + "If the product is not discounted, this field must be null. ", + examples=[2.99], + ) + price_per: PricePerEnum | None = Field( + default=PricePerEnum.KILOGRAM, + description="""if the price is about a barcode-less product + (if `category_tag` is provided), this field must be set to `KILOGRAM` + or `UNIT` (KILOGRAM by default). + This field is set to null and ignored if `product_code` is provided. + """, + examples=["KILOGRAM", "UNIT"], + ) + currency: CurrencyEnum = Field( + description="currency of the price, as a string. " + "The currency must be a valid currency code. " + "See https://en.wikipedia.org/wiki/ISO_4217 for a list of valid currency codes.", + examples=["EUR", "USD"], + ) + date: datetime.date = Field( + description="date when the product was bought.", examples=["2024-01-01"] + ) + @model_validator(mode="after") + def check_price_discount(self): # type: ignore + """ + Check that: + - `price_is_discounted` is true if `price_without_discount` is passed + - `price_without_discount` is greater than `price` + """ + if self.price_without_discount is not None: + if not self.price_is_discounted: + raise ValueError( + "`price_is_discounted` must be true if `price_without_discount` is filled" + ) + if self.price_without_discount <= self.price: + raise ValueError( + "`price_without_discount` must be greater than `price`" + ) + return self + + +class PriceCreate(PriceBase): product_code: str | None = Field( default=None, min_length=1, @@ -291,37 +375,6 @@ class PriceCreate(BaseModel): If one of the origins is not valid, the price will be rejected.""", examples=[["en:france"], ["en:california"]], ) - price: float = Field( - gt=0, - description="price of the product, without its currency, taxes included.", - examples=[1.99], - ) - price_is_discounted: bool = Field( - default=False, - description="true if the price is discounted.", - examples=[True], - ) - price_without_discount: float | None = Field( - default=None, - description="price of the product without discount, without its currency, taxes included. " - "If the product is not discounted, this field must be null. ", - examples=[2.99], - ) - price_per: PricePerEnum | None = Field( - default=PricePerEnum.KILOGRAM, - description="""if the price is about a barcode-less product - (if `category_tag` is provided), this field must be set to `KILOGRAM` - or `UNIT` (KILOGRAM by default). - This field is set to null and ignored if `product_code` is provided. - """, - examples=["KILOGRAM", "UNIT"], - ) - currency: CurrencyEnum = Field( - description="currency of the price, as a string. " - "The currency must be a valid currency code. " - "See https://en.wikipedia.org/wiki/ISO_4217 for a list of valid currency codes.", - examples=["EUR", "USD"], - ) location_osm_id: int = Field( gt=0, description="ID of the location in OpenStreetMap: the store where the product was bought.", @@ -333,9 +386,6 @@ class PriceCreate(BaseModel): "information about the store using the ID.", examples=["NODE", "WAY", "RELATION"], ) - date: datetime.date = Field( - description="date when the product was bought.", examples=["2024-01-01"] - ) proof_id: int | None = Field( default=None, description="ID of the proof, if any. The proof is a file (receipt or price tag image) " @@ -345,15 +395,6 @@ class PriceCreate(BaseModel): examples=[15], ) - -class PriceCreateWithValidation(PriceCreate): - """A version of `PriceCreate` with taxonomy validations. - - These validations are not done in the `PriceCreate` model because they - they are time-consuming and only necessary when creating a price from - the API. - """ - @field_validator("labels_tags") def labels_tags_is_valid(cls, v: list[str] | None) -> list[str] | None: if v is not None: @@ -421,35 +462,10 @@ def set_price_per_to_null_if_barcode(self): # type: ignore self.price_per = None return self - @model_validator(mode="after") - def check_price_discount(self): # type: ignore - """ - Check that: - - `price_is_discounted` is true if `price_without_discount` is passed - - `price_without_discount` is greater than `price` - """ - if self.price_without_discount is not None: - if not self.price_is_discounted: - raise ValueError( - "`price_is_discounted` must be true if `price_without_discount` is filled" - ) - if self.price_without_discount <= self.price: - raise ValueError( - "`price_without_discount` must be greater than `price`" - ) - return self - -class PriceBasicUpdatableFields(BaseModel): - price: float | None = None - price_is_discounted: bool | None = None - price_without_discount: float | None = None - price_per: PricePerEnum | None = None - currency: CurrencyEnum | None = None - date: datetime.date | None = None - - class Config: - extra = "forbid" +@partial_model +class PriceUpdate(PriceBase): + pass class PriceFull(PriceCreate): diff --git a/tests/integration/test_api.py b/tests/integration/test_api.py index 080b7f4d..27a8571d 100644 --- a/tests/integration/test_api.py +++ b/tests/integration/test_api.py @@ -628,7 +628,7 @@ def test_get_prices_filters(db_session, user_session: SessionModel, clean_prices user_session.user, ) crud.create_price( - db_session, PRICE_1.model_copy(update={"price": 5.10}), user_session.user + db_session, PRICE_1.model_copy(update={"price": 4.1}), user_session.user ) crud.create_price( db_session, @@ -653,8 +653,8 @@ def test_get_prices_filters(db_session, user_session: SessionModel, clean_prices response = client.get("/api/v1/prices?category_tag=en:tomatoes") assert response.status_code == 200 assert len(response.json()["items"]) == 1 - # 1 price with price > 5 - response = client.get("/api/v1/prices?price__gt=5") + # 1 price with price > 4 + response = client.get("/api/v1/prices?price__gt=4") assert response.status_code == 200 assert len(response.json()["items"]) == 1 # 1 price with currency USD @@ -694,7 +694,7 @@ def test_update_price( # create price db_price = crud.create_price(db_session, PRICE_1, user_session.user) - new_price = 5.5 + new_price = 4.5 PRICE_UPDATE_PARTIAL = {"price": new_price} # without authentication response = client.patch(f"/api/v1/prices/{db_price.id}") @@ -741,14 +741,21 @@ def test_update_price( response.json()["price_is_discounted"] != PRICE_1.price_is_discounted ) # False assert response.json()["price_without_discount"] is None - # with authentication and price owner but extra fields - PRICE_UPDATE_PARTIAL_WRONG = {**PRICE_UPDATE_PARTIAL, "proof_id": 1} - response = client.patch( - f"/api/v1/prices/{db_price.id}", - headers={"Authorization": f"Bearer {user_session.token}"}, - json=jsonable_encoder(PRICE_UPDATE_PARTIAL_WRONG), - ) - assert response.status_code == 422 + # with authentication and price owner but validation error + PRICE_UPDATE_PARTIAL_WRONG_LIST = [ + {**PRICE_UPDATE_PARTIAL, "proof_id": 1}, # extra field + {**PRICE_UPDATE_PARTIAL, "price": -1}, # price negative + {"price_without_discount": 3, "price_is_discounted": False}, # incoherence + {"price_without_discount": 3} # price_without_discount < price + # {**PRICE_UPDATE_PARTIAL, "price_per": "UNIT"} + ] + for PRICE_UPDATE_PARTIAL_WRONG in PRICE_UPDATE_PARTIAL_WRONG_LIST: + response = client.patch( + f"/api/v1/prices/{db_price.id}", + headers={"Authorization": f"Bearer {user_session.token}"}, + json=jsonable_encoder(PRICE_UPDATE_PARTIAL_WRONG), + ) + assert response.status_code == 422 def test_update_price_moderator( @@ -757,7 +764,7 @@ def test_update_price_moderator( # create price db_price = crud.create_price(db_session, PRICE_1, user_session.user) - new_price = 5.5 + new_price = 4.5 PRICE_UPDATE_PARTIAL = {"price": new_price} # user_1 is moderator, not owner diff --git a/tests/unit/test_schema.py b/tests/unit/test_schema.py index 0b57a808..b9e9c600 100644 --- a/tests/unit/test_schema.py +++ b/tests/unit/test_schema.py @@ -3,12 +3,12 @@ import pydantic import pytest -from app.schemas import CurrencyEnum, LocationOSMEnum, PriceCreateWithValidation +from app.schemas import CurrencyEnum, LocationOSMEnum, PriceCreate class TestPriceCreate: def test_simple_price_with_barcode(self): - price = PriceCreateWithValidation( + price = PriceCreate( product_code="5414661000456", location_osm_id=123, location_osm_type=LocationOSMEnum.NODE, @@ -24,7 +24,7 @@ def test_simple_price_with_barcode(self): assert price.date == datetime.date.fromisoformat("2021-01-01") def test_simple_price_with_category(self): - price = PriceCreateWithValidation( + price = PriceCreate( category_tag="en:Fresh-apricots", labels_tags=["en:Organic", "fr:AB-agriculture-biologique"], origins_tags=["en:California", "en:Sweden"], @@ -40,7 +40,7 @@ def test_simple_price_with_category(self): def test_simple_price_with_invalid_taxonomized_values(self): with pytest.raises(pydantic.ValidationError, match="Invalid category tag"): - PriceCreateWithValidation( + PriceCreate( category_tag="en:unknown-category", location_osm_id=123, location_osm_type=LocationOSMEnum.NODE, @@ -50,7 +50,7 @@ def test_simple_price_with_invalid_taxonomized_values(self): ) with pytest.raises(pydantic.ValidationError, match="Invalid label tag"): - PriceCreateWithValidation( + PriceCreate( category_tag="en:carrots", labels_tags=["en:invalid"], location_osm_id=123, @@ -61,7 +61,7 @@ def test_simple_price_with_invalid_taxonomized_values(self): ) with pytest.raises(pydantic.ValidationError, match="Invalid origin tag"): - PriceCreateWithValidation( + PriceCreate( category_tag="en:carrots", origins_tags=["en:invalid"], location_osm_id=123, @@ -76,7 +76,7 @@ def test_simple_price_with_product_code_and_labels_tags_raise(self): pydantic.ValidationError, match="`labels_tags` can only be set for products without barcode", ): - PriceCreateWithValidation( + PriceCreate( product_code="5414661000456", labels_tags=["en:Organic", "fr:AB-agriculture-biologique"], location_osm_id=123, @@ -91,7 +91,7 @@ def test_price_discount_raise(self): pydantic.ValidationError, match="`price_is_discounted` must be true if `price_without_discount` is filled", ): - PriceCreateWithValidation( + PriceCreate( product_code="5414661000456", location_osm_id=123, location_osm_type=LocationOSMEnum.NODE, @@ -105,7 +105,7 @@ def test_price_discount_raise(self): pydantic.ValidationError, match="`price_without_discount` must be greater than `price`", ): - PriceCreateWithValidation( + PriceCreate( product_code="5414661000456", location_osm_id=123, location_osm_type=LocationOSMEnum.NODE,