diff --git a/fastapi_crudrouter/core/_base.py b/fastapi_crudrouter/core/_base.py index e45d33fe..4c662885 100644 --- a/fastapi_crudrouter/core/_base.py +++ b/fastapi_crudrouter/core/_base.py @@ -5,7 +5,7 @@ from fastapi.types import DecoratedCallable from ._types import T, DEPENDENCIES -from ._utils import pagination_factory, schema_factory +from ._utils import pagination_factory, schema_factory, make_optional NOT_FOUND = HTTPException(404, "Item not found") @@ -28,6 +28,7 @@ def __init__( get_one_route: Union[bool, DEPENDENCIES] = True, create_route: Union[bool, DEPENDENCIES] = True, update_route: Union[bool, DEPENDENCIES] = True, + patch_route: Union[bool, DEPENDENCIES] = True, delete_one_route: Union[bool, DEPENDENCIES] = True, delete_all_route: Union[bool, DEPENDENCIES] = True, **kwargs: Any, @@ -46,6 +47,13 @@ def __init__( if update_schema else schema_factory(self.schema, pk_field_name=self._pk, name="Update") ) + self.patch_schema = ( + make_optional(update_schema) + if update_schema + else make_optional( + schema_factory(self.schema, pk_field_name=self._pk, name="Patch") + ) + ) prefix = str(prefix if prefix else self.schema.__name__).lower() prefix = self._base_path + prefix.strip("/") @@ -105,6 +113,17 @@ def __init__( error_responses=[NOT_FOUND], ) + if patch_route: + self._add_api_route( + "/{item_id}", + self._patch(), + methods=["PATCH"], + response_model=self.schema, + summary="Partiall Update One", + dependencies=patch_route, + error_responses=[NOT_FOUND], + ) + if delete_one_route: self._add_api_route( "/{item_id}", @@ -161,6 +180,12 @@ def put( self.remove_api_route(path, ["PUT"]) return super().put(path, *args, **kwargs) + def patch( + self, path: str, *args: Any, **kwargs: Any + ) -> Callable[[DecoratedCallable], DecoratedCallable]: + self.remove_api_route(path, ["PATCH"]) + return super().put(path, *args, **kwargs) + def delete( self, path: str, *args: Any, **kwargs: Any ) -> Callable[[DecoratedCallable], DecoratedCallable]: @@ -193,6 +218,10 @@ def _create(self, *args: Any, **kwargs: Any) -> Callable[..., Any]: def _update(self, *args: Any, **kwargs: Any) -> Callable[..., Any]: raise NotImplementedError + @abstractmethod + def _patch(self, *args: Any, **kwargs: Any) -> Callable[..., Any]: + raise NotImplementedError + @abstractmethod def _delete_one(self, *args: Any, **kwargs: Any) -> Callable[..., Any]: raise NotImplementedError @@ -206,4 +235,12 @@ def _raise(self, e: Exception, status_code: int = 422) -> HTTPException: @staticmethod def get_routes() -> List[str]: - return ["get_all", "create", "delete_all", "get_one", "update", "delete_one"] + return [ + "get_all", + "create", + "delete_all", + "get_one", + "update", + "patch", + "delete_one", + ] diff --git a/fastapi_crudrouter/core/_utils.py b/fastapi_crudrouter/core/_utils.py index ef3562e4..e3c87805 100644 --- a/fastapi_crudrouter/core/_utils.py +++ b/fastapi_crudrouter/core/_utils.py @@ -2,7 +2,6 @@ from fastapi import Depends, HTTPException from pydantic import create_model - from ._types import T, PAGINATION, PYDANTIC_SCHEMA @@ -12,6 +11,20 @@ def __init__(self, *args, **kwargs) -> None: # type: ignore self.__dict__ = self +# TODO this lets the patch request come with arbitrary number of fields +# Need to validate the fields that are present only in the schema +def make_optional(baseclass:Type[T]) -> Type[T]: + # Extracts the fields and validators from the baseclass and make fields optional + fields = baseclass.__fields__ + validators = {"__validators__": baseclass.__validators__} + optional_fields = { + key: (Optional[item.type_], None) for key, item in fields.items() + } + return create_model( + f"{baseclass.__name__}Optional", **optional_fields, __validators__=validators + ) + + def get_pk_type(schema: Type[PYDANTIC_SCHEMA], pk_field: str) -> Any: try: return schema.__fields__[pk_field].type_ diff --git a/fastapi_crudrouter/core/databases.py b/fastapi_crudrouter/core/databases.py index 7ea3c711..dc7cc628 100644 --- a/fastapi_crudrouter/core/databases.py +++ b/fastapi_crudrouter/core/databases.py @@ -54,6 +54,7 @@ def __init__( get_one_route: Union[bool, DEPENDENCIES] = True, create_route: Union[bool, DEPENDENCIES] = True, update_route: Union[bool, DEPENDENCIES] = True, + patch_route: Union[bool, DEPENDENCIES] = True, delete_one_route: Union[bool, DEPENDENCIES] = True, delete_all_route: Union[bool, DEPENDENCIES] = True, **kwargs: Any @@ -79,6 +80,7 @@ def __init__( get_one_route=get_one_route, create_route=create_route, update_route=update_route, + patch_route=patch_route, delete_one_route=delete_one_route, delete_all_route=delete_all_route, **kwargs @@ -140,6 +142,23 @@ async def route( return route + def _patch(self, *args: Any, **kwargs: Any) -> CALLABLE: + async def route( + item_id: self._pk_type, schema: self.patch_schema # type: ignore + ) -> Model: + query = self.table.update().where(self._pk_col == item_id) + + try: + await self.db.fetch_one( + query=query, + values=schema.dict(exclude={self._pk}, exclude_unset=True), + ) + return await self._get_one()(item_id) + except Exception as e: + raise NOT_FOUND from e + + return route + def _delete_all(self, *args: Any, **kwargs: Any) -> CALLABLE_LIST: async def route() -> List[Model]: query = self.table.delete() diff --git a/fastapi_crudrouter/core/gino_starlette.py b/fastapi_crudrouter/core/gino_starlette.py index d07d893e..e661d880 100644 --- a/fastapi_crudrouter/core/gino_starlette.py +++ b/fastapi_crudrouter/core/gino_starlette.py @@ -39,6 +39,7 @@ def __init__( get_one_route: Union[bool, DEPENDENCIES] = True, create_route: Union[bool, DEPENDENCIES] = True, update_route: Union[bool, DEPENDENCIES] = True, + patch_route: Union[bool, DEPENDENCIES] = True, delete_one_route: Union[bool, DEPENDENCIES] = True, delete_all_route: Union[bool, DEPENDENCIES] = True, **kwargs: Any @@ -61,6 +62,7 @@ def __init__( get_one_route=get_one_route, create_route=create_route, update_route=update_route, + patch_route=patch_route, delete_one_route=delete_one_route, delete_all_route=delete_all_route, **kwargs @@ -120,6 +122,23 @@ async def route( return route + def _patch(self, *args: Any, **kwargs: Any) -> CALLABLE: + async def route( + item_id: self._pk_type, # type: ignore + model: self.patch_schema, # type: ignore + ) -> Model: + try: + db_model: Model = await self._get_one()(item_id) + async with self.db.transaction(): + model = model.dict(exclude={self._pk}, exclude_unset=True) + await db_model.update(**model).apply() + db_model: Model = await self._get_one()(item_id) + return db_model + except (IntegrityError, UniqueViolationError) as e: + self._raise(e) + + return route + def _delete_all(self, *args: Any, **kwargs: Any) -> CALLABLE_LIST: async def route() -> List[Model]: await self.db_model.delete.gino.status() diff --git a/fastapi_crudrouter/core/mem.py b/fastapi_crudrouter/core/mem.py index d4e13c11..8239a13c 100644 --- a/fastapi_crudrouter/core/mem.py +++ b/fastapi_crudrouter/core/mem.py @@ -20,6 +20,7 @@ def __init__( get_one_route: Union[bool, DEPENDENCIES] = True, create_route: Union[bool, DEPENDENCIES] = True, update_route: Union[bool, DEPENDENCIES] = True, + patch_route: Union[bool, DEPENDENCIES] = True, delete_one_route: Union[bool, DEPENDENCIES] = True, delete_all_route: Union[bool, DEPENDENCIES] = True, **kwargs: Any @@ -35,6 +36,7 @@ def __init__( get_one_route=get_one_route, create_route=create_route, update_route=update_route, + patch_route=patch_route, delete_one_route=delete_one_route, delete_all_route=delete_all_route, **kwargs @@ -89,6 +91,20 @@ def route(item_id: int, model: self.update_schema) -> SCHEMA: # type: ignore return route + def _patch(self, *args: Any, **kwargs: Any) -> CALLABLE: + def route(item_id: int, model: self.patch_schema) -> SCHEMA: # type: ignore + for ind, model_ in enumerate(self.models): + if model_.id == item_id: # type: ignore + stored_item = model_.dict() + updated_item = model.dict(exclude_unset=True) + stored_item.update(updated_item) + self.models[ind] = self.schema(**stored_item) + return self.models[ind] + + raise NOT_FOUND + + return route + def _delete_all(self, *args: Any, **kwargs: Any) -> CALLABLE_LIST: def route() -> List[SCHEMA]: self.models = [] diff --git a/fastapi_crudrouter/core/ormar.py b/fastapi_crudrouter/core/ormar.py index 99952600..9679ace4 100644 --- a/fastapi_crudrouter/core/ormar.py +++ b/fastapi_crudrouter/core/ormar.py @@ -40,6 +40,7 @@ def __init__( get_one_route: Union[bool, DEPENDENCIES] = True, create_route: Union[bool, DEPENDENCIES] = True, update_route: Union[bool, DEPENDENCIES] = True, + patch_route: Union[bool, DEPENDENCIES] = True, delete_one_route: Union[bool, DEPENDENCIES] = True, delete_all_route: Union[bool, DEPENDENCIES] = True, **kwargs: Any @@ -60,6 +61,7 @@ def __init__( get_one_route=get_one_route, create_route=create_route, update_route=update_route, + patch_route=patch_route, delete_one_route=delete_one_route, delete_all_route=delete_all_route, **kwargs @@ -120,6 +122,22 @@ async def route( return route + def _patch(self, *args: Any, **kwargs: Any) -> CALLABLE: + async def route( + item_id: self._pk_type, # type: ignore + model: self.patch_schema, # type: ignore + ) -> Model: + filter_ = {self._pk: item_id} + try: + await self.schema.objects.filter(_exclude=False, **filter_).update( + **model.dict(exclude_unset=True) + ) + except self._INTEGRITY_ERROR as e: + self._raise(e) + return await self._get_one()(item_id) + + return route + def _delete_all(self, *args: Any, **kwargs: Any) -> CALLABLE_LIST: async def route() -> List[Optional[Model]]: await self.schema.objects.delete(each=True) diff --git a/fastapi_crudrouter/core/sqlalchemy.py b/fastapi_crudrouter/core/sqlalchemy.py index 58270f34..83d964b3 100644 --- a/fastapi_crudrouter/core/sqlalchemy.py +++ b/fastapi_crudrouter/core/sqlalchemy.py @@ -37,6 +37,7 @@ def __init__( get_one_route: Union[bool, DEPENDENCIES] = True, create_route: Union[bool, DEPENDENCIES] = True, update_route: Union[bool, DEPENDENCIES] = True, + patch_route: Union[bool, DEPENDENCIES] = True, delete_one_route: Union[bool, DEPENDENCIES] = True, delete_all_route: Union[bool, DEPENDENCIES] = True, **kwargs: Any @@ -61,6 +62,7 @@ def __init__( get_one_route=get_one_route, create_route=create_route, update_route=update_route, + patch_route=patch_route, delete_one_route=delete_one_route, delete_all_route=delete_all_route, **kwargs @@ -137,6 +139,31 @@ def route( return route + def _patch(self, *args: Any, **kwargs: Any) -> CALLABLE: + def route( + item_id: self._pk_type, # type: ignore + model: self.patch_schema, # type: ignore + db: Session = Depends(self.db_func), + ) -> Model: + try: + db_model: Model = self._get_one()(item_id, db) + + for key, value in model.dict( + exclude={self._pk}, exclude_unset=True + ).items(): + if hasattr(db_model, key): + setattr(db_model, key, value) + + db.commit() + db.refresh(db_model) + + return db_model + except IntegrityError as e: + db.rollback() + self._raise(e) + + return route + def _delete_all(self, *args: Any, **kwargs: Any) -> CALLABLE_LIST: def route(db: Session = Depends(self.db_func)) -> List[Model]: db.query(self.db_model).delete() diff --git a/fastapi_crudrouter/core/tortoise.py b/fastapi_crudrouter/core/tortoise.py index 52972a48..52fc3ad4 100644 --- a/fastapi_crudrouter/core/tortoise.py +++ b/fastapi_crudrouter/core/tortoise.py @@ -30,6 +30,7 @@ def __init__( get_one_route: Union[bool, DEPENDENCIES] = True, create_route: Union[bool, DEPENDENCIES] = True, update_route: Union[bool, DEPENDENCIES] = True, + patch_route: Union[bool, DEPENDENCIES] = True, delete_one_route: Union[bool, DEPENDENCIES] = True, delete_all_route: Union[bool, DEPENDENCIES] = True, **kwargs: Any @@ -52,6 +53,7 @@ def __init__( get_one_route=get_one_route, create_route=create_route, update_route=update_route, + patch_route=patch_route, delete_one_route=delete_one_route, delete_all_route=delete_all_route, **kwargs @@ -98,6 +100,17 @@ async def route( return route + def _patch(self, *args: Any, **kwargs: Any) -> CALLABLE: + async def route( + item_id: int, model: self.patch_schema # type: ignore + ) -> Model: + await self.db_model.filter(id=item_id).update( + **model.dict(exclude_unset=True) + ) + return await self._get_one()(item_id) + + return route + def _delete_all(self, *args: Any, **kwargs: Any) -> CALLABLE_LIST: async def route() -> List[Model]: await self.db_model.all().delete() diff --git a/tests/test_openapi_schema.py b/tests/test_openapi_schema.py index 61faf752..4f2c8cbc 100644 --- a/tests/test_openapi_schema.py +++ b/tests/test_openapi_schema.py @@ -24,9 +24,8 @@ def test_schema_tags(self, client): paths = schema["paths"] assert len(paths) == len(PATH_TAGS) + print(paths.items()) for path, method in paths.items(): - assert len(method) == 3 - for m in method: assert method[m]["tags"] == PATH_TAGS[path] @@ -41,7 +40,7 @@ def test_response_types(self, client, path): assert "422" in paths[path]["post"]["responses"] item_path = path + "/{item_id}" - for method in ["get", "put", "delete"]: + for method in ["get", "put", "patch", "delete"]: assert "200" in paths[item_path][method]["responses"] assert "404" in paths[item_path][method]["responses"] assert "422" in paths[item_path][method]["responses"] diff --git a/tests/test_router.py b/tests/test_router.py index 582edec0..031a8061 100644 --- a/tests/test_router.py +++ b/tests/test_router.py @@ -70,6 +70,31 @@ def test_update(client, url: str = URL, model: Dict = None, id_key: str = "id"): assert not compare_dict(res.json(), model, exclude=[id_key]) +def test_patch(client, url: str = URL, model: Dict = None, id_key: str = "id"): + test_get(client, url, expected_length=0) + + model = model or basic_potato + res = client.post(url, json=model) + data = res.json() + assert res.status_code == 200 + + test_get(client, url, expected_length=1) + + tuber = {} + tuber["color"] = "yellow" + resp_tuber = {k: v for k, v in model.items()} + resp_tuber["color"] = "yellow" + res = client.patch(f"{url}/{data[id_key]}", json=tuber) + assert res.status_code == 200 + assert compare_dict(res.json(), resp_tuber, exclude=[id_key]) + assert not compare_dict(res.json(), model, exclude=[id_key]) + + res = client.get(f"{url}/{data[id_key]}") + assert res.status_code == 200 + assert compare_dict(res.json(), resp_tuber, exclude=[id_key]) + assert not compare_dict(res.json(), model, exclude=[id_key]) + + def test_delete_one(client, url: str = URL, model: Dict = None, id_key: str = "id"): model = model or basic_potato res = client.post(url, json=model)