Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Support default_factory on primary keys of schemas for create routes #166

Open
wants to merge 11 commits into
base: master
Choose a base branch
from
19 changes: 17 additions & 2 deletions fastapi_crudrouter/core/_utils.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
from typing import Optional, Type, Any
from typing import Optional, Type, Any, Tuple

from fastapi import Depends, HTTPException
from pydantic import create_model
Expand All @@ -12,13 +12,28 @@ def __init__(self, *args, **kwargs) -> None: # type: ignore
self.__dict__ = self


def get_pk_type(schema: Type[PYDANTIC_SCHEMA], pk_field: str) -> Any:
def get_pk_type(schema: Type[PYDANTIC_SCHEMA], pk_field: str = "id") -> Any:
try:
return schema.__fields__[pk_field].type_
except KeyError:
return int


def create_schema_default_factory(
schema_cls: Type[T], create_schema_instance: T, pk_field_name: str = "id"
) -> Tuple[T, bool]:
"""
Is used to check for default_factory for the pk on a Schema,
passing the CreateSchema values into the Schema if a
default_factory on the pk exists
"""

if callable(schema_cls.__fields__[pk_field_name].default_factory):
return schema_cls(**create_schema_instance.dict()), True
else:
return create_schema_instance, False


def schema_factory(
schema_cls: Type[T], pk_field_name: str = "id", name: str = "Create"
) -> Type[T]:
Expand Down
7 changes: 6 additions & 1 deletion fastapi_crudrouter/core/databases.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,7 +13,7 @@

from . import CRUDGenerator, NOT_FOUND
from ._types import PAGINATION, PYDANTIC_SCHEMA, DEPENDENCIES
from ._utils import AttrDict, get_pk_type
from ._utils import AttrDict, get_pk_type, create_schema_default_factory

try:
from sqlalchemy.sql.schema import Table
Expand Down Expand Up @@ -111,6 +111,11 @@ def _create(self, *args: Any, **kwargs: Any) -> CALLABLE:
async def route(
schema: self.create_schema, # type: ignore
) -> Model:
schema, _ = create_schema_default_factory(
schema_cls=self.schema,
create_schema_instance=schema,
pk_field_name=self._pk,
)
query = self.table.insert()

try:
Expand Down
7 changes: 7 additions & 0 deletions fastapi_crudrouter/core/gino_starlette.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@
from . import NOT_FOUND, CRUDGenerator, _utils
from ._types import DEPENDENCIES, PAGINATION
from ._types import PYDANTIC_SCHEMA as SCHEMA
from ._utils import create_schema_default_factory

try:
from asyncpg.exceptions import UniqueViolationError
Expand Down Expand Up @@ -94,6 +95,12 @@ def _create(self, *args: Any, **kwargs: Any) -> CALLABLE:
async def route(
model: self.create_schema, # type: ignore
) -> Model:
model, _ = create_schema_default_factory(
schema_cls=self.schema,
create_schema_instance=model,
pk_field_name=self._pk,
)

try:
async with self.db.transaction():
db_model: Model = await self.db_model.create(**model.dict())
Expand Down
23 changes: 19 additions & 4 deletions fastapi_crudrouter/core/mem.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,10 @@
from typing import Any, Callable, List, Type, cast, Optional, Union

from fastapi import HTTPException

from . import CRUDGenerator, NOT_FOUND
from ._types import DEPENDENCIES, PAGINATION, PYDANTIC_SCHEMA as SCHEMA
from ._utils import get_pk_type, create_schema_default_factory

CALLABLE = Callable[..., SCHEMA]
CALLABLE_LIST = Callable[..., List[SCHEMA]]
Expand All @@ -24,6 +27,8 @@ def __init__(
delete_all_route: Union[bool, DEPENDENCIES] = True,
**kwargs: Any
) -> None:
self._pk_type: type = get_pk_type(schema)

super().__init__(
schema=schema,
create_schema=create_schema,
Expand Down Expand Up @@ -57,7 +62,7 @@ def route(pagination: PAGINATION = self.pagination) -> List[SCHEMA]:
return route

def _get_one(self, *args: Any, **kwargs: Any) -> CALLABLE:
def route(item_id: int) -> SCHEMA:
def route(item_id: self._pk_type) -> SCHEMA:
for model in self.models:
if model.id == item_id: # type: ignore
return model
Expand All @@ -68,16 +73,26 @@ def route(item_id: int) -> SCHEMA:

def _create(self, *args: Any, **kwargs: Any) -> CALLABLE:
def route(model: self.create_schema) -> SCHEMA: # type: ignore
model, using_default_factory = create_schema_default_factory(
schema_cls=self.schema,
create_schema_instance=model,
pk_field_name=self._pk,
)
model_dict = model.dict()
model_dict["id"] = self._get_next_id()
if using_default_factory:
for _model in self.models:
if _model.id == model.id: # type: ignore
raise HTTPException(422, "Key already exists") from None
else:
model_dict["id"] = self._get_next_id()
ready_model = self.schema(**model_dict)
self.models.append(ready_model)
return ready_model

return route

def _update(self, *args: Any, **kwargs: Any) -> CALLABLE:
def route(item_id: int, model: self.update_schema) -> SCHEMA: # type: ignore
def route(item_id: self._pk_type, model: self.update_schema) -> SCHEMA: # type: ignore
for ind, model_ in enumerate(self.models):
if model_.id == item_id: # type: ignore
self.models[ind] = self.schema(
Expand All @@ -97,7 +112,7 @@ def route() -> List[SCHEMA]:
return route

def _delete_one(self, *args: Any, **kwargs: Any) -> CALLABLE:
def route(item_id: int) -> SCHEMA:
def route(item_id: self._pk_type) -> SCHEMA:
for ind, model in enumerate(self.models):
if model.id == item_id: # type: ignore
del self.models[ind]
Expand Down
10 changes: 10 additions & 0 deletions fastapi_crudrouter/core/ormar.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,7 @@

from . import CRUDGenerator, NOT_FOUND, _utils
from ._types import DEPENDENCIES, PAGINATION
from ._utils import create_schema_default_factory

try:
from ormar import Model, NoMatch
Expand All @@ -33,6 +34,7 @@ def __init__(
schema: Type[Model],
create_schema: Optional[Type[Model]] = None,
update_schema: Optional[Type[Model]] = None,
default_factory_schema: Optional[Type[Model]] = None,
prefix: Optional[str] = None,
tags: Optional[List[str]] = None,
paginate: Optional[int] = None,
Expand All @@ -48,6 +50,9 @@ def __init__(

self._pk: str = schema.Meta.pkname
self._pk_type: type = _utils.get_pk_type(schema, self._pk)
self.default_factory_schema = (
default_factory_schema if default_factory_schema else schema
)

super().__init__(
schema=schema,
Expand Down Expand Up @@ -94,6 +99,11 @@ async def route(item_id: self._pk_type) -> Model: # type: ignore

def _create(self, *args: Any, **kwargs: Any) -> CALLABLE:
async def route(model: self.create_schema) -> Model: # type: ignore
model, _ = create_schema_default_factory(
schema_cls=self.default_factory_schema,
create_schema_instance=model,
pk_field_name=self._pk,
)
model_dict = model.dict()
if self.schema.Meta.model_fields[self._pk].autoincrement:
model_dict.pop(self._pk, None)
Expand Down
7 changes: 7 additions & 0 deletions fastapi_crudrouter/core/sqlalchemy.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@

from . import CRUDGenerator, NOT_FOUND, _utils
from ._types import DEPENDENCIES, PAGINATION, PYDANTIC_SCHEMA as SCHEMA
from ._utils import create_schema_default_factory

try:
from sqlalchemy.orm import Session
Expand Down Expand Up @@ -102,6 +103,12 @@ def route(
model: self.create_schema, # type: ignore
db: Session = Depends(self.db_func),
) -> Model:
model, _ = create_schema_default_factory(
schema_cls=self.schema,
create_schema_instance=model,
pk_field_name=self._pk,
)

try:
db_model: Model = self.db_model(**model.dict())
db.add(db_model)
Expand Down
13 changes: 10 additions & 3 deletions fastapi_crudrouter/core/tortoise.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@

from . import CRUDGenerator, NOT_FOUND
from ._types import DEPENDENCIES, PAGINATION, PYDANTIC_SCHEMA as SCHEMA
from ._utils import get_pk_type, create_schema_default_factory

try:
from tortoise.models import Model
Expand Down Expand Up @@ -40,6 +41,7 @@ def __init__(

self.db_model = db_model
self._pk: str = db_model.describe()["pk_field"]["db_column"]
self._pk_type: type = get_pk_type(schema, self._pk)

super().__init__(
schema=schema,
Expand Down Expand Up @@ -68,7 +70,7 @@ async def route(pagination: PAGINATION = self.pagination) -> List[Model]:
return route

def _get_one(self, *args: Any, **kwargs: Any) -> CALLABLE:
async def route(item_id: int) -> Model:
async def route(item_id: self._pk_type) -> Model:
model = await self.db_model.filter(id=item_id).first()

if model:
Expand All @@ -80,6 +82,11 @@ async def route(item_id: int) -> Model:

def _create(self, *args: Any, **kwargs: Any) -> CALLABLE:
async def route(model: self.create_schema) -> Model: # type: ignore
model, _ = create_schema_default_factory(
schema_cls=self.schema,
create_schema_instance=model,
pk_field_name=self._pk,
)
db_model = self.db_model(**model.dict())
await db_model.save()

Expand All @@ -89,7 +96,7 @@ async def route(model: self.create_schema) -> Model: # type: ignore

def _update(self, *args: Any, **kwargs: Any) -> CALLABLE:
async def route(
item_id: int, model: self.update_schema # type: ignore
item_id: self._pk_type, model: self.update_schema # type: ignore
) -> Model:
await self.db_model.filter(id=item_id).update(
**model.dict(exclude_unset=True)
Expand All @@ -106,7 +113,7 @@ async def route() -> List[Model]:
return route

def _delete_one(self, *args: Any, **kwargs: Any) -> CALLABLE:
async def route(item_id: int) -> Model:
async def route(item_id: self._pk_type) -> Model:
model: Model = await self._get_one()(item_id)
await self.db_model.filter(id=item_id).delete()

Expand Down
14 changes: 13 additions & 1 deletion tests/__init__.py
Original file line number Diff line number Diff line change
@@ -1,9 +1,11 @@
from pydantic import BaseModel
from uuid import uuid4
from pydantic import BaseModel, Field

from .conf import config

PAGINATION_SIZE = 10
CUSTOM_TAGS = ["Tag1", "Tag2"]
POTATO_TAGS = ["Potato"]


class ORMModel(BaseModel):
Expand All @@ -24,6 +26,16 @@ class Potato(PotatoCreate, ORMModel):
pass


class DefaultFactoryPotatoCreate(BaseModel):
color: str
mass: float


class DefaultFactoryPotato(DefaultFactoryPotatoCreate, ORMModel):
id: str = Field(default_factory=lambda: str(uuid4()))
pass


class CustomPotato(PotatoCreate):
potato_id: int

Expand Down
17 changes: 17 additions & 0 deletions tests/implementations/databases_.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,9 @@
CustomPotato,
PAGINATION_SIZE,
Potato,
DefaultFactoryPotato,
PotatoType,
POTATO_TAGS,
CUSTOM_TAGS,
config,
)
Expand Down Expand Up @@ -47,6 +49,13 @@ def databases_implementation(db_uri: str):
Column("color", String),
Column("type", String),
)
defaultfactorypotatoes = Table(
"defaultfactorypotatoes",
metadata,
Column("id", String, primary_key=True),
Column("color", String),
Column("mass", Float),
)
carrots = Table(
"carrots",
metadata,
Expand Down Expand Up @@ -74,6 +83,14 @@ async def shutdown():
prefix="potato",
paginate=PAGINATION_SIZE,
),
dict(
database=database,
table=defaultfactorypotatoes,
schema=DefaultFactoryPotato,
prefix="defaultfactorypotato",
tags=POTATO_TAGS,
paginate=PAGINATION_SIZE,
),
dict(
database=database,
table=carrots,
Expand Down
16 changes: 16 additions & 0 deletions tests/implementations/gino_.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,8 @@
CarrotCreate,
CarrotUpdate,
CustomPotato,
DefaultFactoryPotato,
POTATO_TAGS,
Potato,
PotatoType,
config,
Expand Down Expand Up @@ -47,6 +49,12 @@ class PotatoModel(db.Model):
color = db.Column(db.String)
type = db.Column(db.String)

class DefaultFactoryPotatoModel(db.Model):
__tablename__ = "defaultfactorypotatoes"
id = db.Column(db.String, primary_key=True, index=True)
mass = db.Column(db.Float)
color = db.Column(db.String)

class CarrotModel(db.Model):
__tablename__ = "carrots"
id = db.Column(db.Integer, primary_key=True, index=True)
Expand All @@ -63,6 +71,14 @@ class CarrotModel(db.Model):
prefix="potato",
paginate=PAGINATION_SIZE,
),
dict(
schema=DefaultFactoryPotato,
db_model=DefaultFactoryPotatoModel,
db=db,
prefix="defaultfactorypotato",
tags=POTATO_TAGS,
paginate=PAGINATION_SIZE,
),
dict(
schema=Carrot,
db_model=CarrotModel,
Expand Down
16 changes: 14 additions & 2 deletions tests/implementations/memory.py
Original file line number Diff line number Diff line change
@@ -1,13 +1,22 @@
from fastapi import FastAPI

from fastapi_crudrouter import MemoryCRUDRouter
from tests import Potato, Carrot, CarrotUpdate, PAGINATION_SIZE, CUSTOM_TAGS
from tests import (
Potato,
DefaultFactoryPotato,
Carrot,
CarrotUpdate,
PAGINATION_SIZE,
CUSTOM_TAGS,
POTATO_TAGS,
)


def memory_implementation(**kwargs):
app = FastAPI()
router_settings = [
dict(schema=Potato, paginate=PAGINATION_SIZE),
dict(schema=DefaultFactoryPotato, paginate=PAGINATION_SIZE, tags=POTATO_TAGS),
dict(schema=Carrot, update_schema=CarrotUpdate, tags=CUSTOM_TAGS),
]

Expand All @@ -17,4 +26,7 @@ def memory_implementation(**kwargs):
if __name__ == "__main__":
import uvicorn

uvicorn.run(memory_implementation(), port=5000)
app, route_type, routes = memory_implementation()
for route in routes:
app.include_router(route_type(**route))
uvicorn.run(app, port=5000)
Loading