Skip to content
Open
19 changes: 17 additions & 2 deletions fastapi_crudrouter/core/_utils.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
from typing import Optional, Type, Any
from typing import Optional, Type, Any, Tuple

from fastapi import Depends, HTTPException
from pydantic import create_model
Expand All @@ -12,13 +12,28 @@ def __init__(self, *args, **kwargs) -> None: # type: ignore
self.__dict__ = self


def get_pk_type(schema: Type[PYDANTIC_SCHEMA], pk_field: str) -> Any:
def get_pk_type(schema: Type[PYDANTIC_SCHEMA], pk_field: str = "id") -> Any:
try:
return schema.__fields__[pk_field].type_
except KeyError:
return int


def create_schema_default_factory(
schema_cls: Type[T], create_schema_instance: T, pk_field_name: str = "id"
) -> Tuple[T, bool]:
"""
Is used to check for default_factory for the pk on a Schema,
passing the CreateSchema values into the Schema if a
default_factory on the pk exists
"""

if callable(schema_cls.__fields__[pk_field_name].default_factory):
return schema_cls(**create_schema_instance.dict()), True
else:
return create_schema_instance, False


def schema_factory(
schema_cls: Type[T], pk_field_name: str = "id", name: str = "Create"
) -> Type[T]:
Expand Down
7 changes: 6 additions & 1 deletion fastapi_crudrouter/core/databases.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,7 +13,7 @@

from . import CRUDGenerator, NOT_FOUND
from ._types import PAGINATION, PYDANTIC_SCHEMA, DEPENDENCIES
from ._utils import AttrDict, get_pk_type
from ._utils import AttrDict, get_pk_type, create_schema_default_factory

try:
from sqlalchemy.sql.schema import Table
Expand Down Expand Up @@ -111,6 +111,11 @@ def _create(self, *args: Any, **kwargs: Any) -> CALLABLE:
async def route(
schema: self.create_schema, # type: ignore
) -> Model:
schema, _ = create_schema_default_factory(
schema_cls=self.schema,
create_schema_instance=schema,
pk_field_name=self._pk,
)
query = self.table.insert()

try:
Expand Down
7 changes: 7 additions & 0 deletions fastapi_crudrouter/core/gino_starlette.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@
from . import NOT_FOUND, CRUDGenerator, _utils
from ._types import DEPENDENCIES, PAGINATION
from ._types import PYDANTIC_SCHEMA as SCHEMA
from ._utils import create_schema_default_factory

try:
from asyncpg.exceptions import UniqueViolationError
Expand Down Expand Up @@ -94,6 +95,12 @@ def _create(self, *args: Any, **kwargs: Any) -> CALLABLE:
async def route(
model: self.create_schema, # type: ignore
) -> Model:
model, _ = create_schema_default_factory(
schema_cls=self.schema,
create_schema_instance=model,
pk_field_name=self._pk,
)

try:
async with self.db.transaction():
db_model: Model = await self.db_model.create(**model.dict())
Expand Down
23 changes: 19 additions & 4 deletions fastapi_crudrouter/core/mem.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,10 @@
from typing import Any, Callable, List, Type, cast, Optional, Union

from fastapi import HTTPException

from . import CRUDGenerator, NOT_FOUND
from ._types import DEPENDENCIES, PAGINATION, PYDANTIC_SCHEMA as SCHEMA
from ._utils import get_pk_type, create_schema_default_factory

CALLABLE = Callable[..., SCHEMA]
CALLABLE_LIST = Callable[..., List[SCHEMA]]
Expand All @@ -24,6 +27,8 @@ def __init__(
delete_all_route: Union[bool, DEPENDENCIES] = True,
**kwargs: Any
) -> None:
self._pk_type: type = get_pk_type(schema)

super().__init__(
schema=schema,
create_schema=create_schema,
Expand Down Expand Up @@ -57,7 +62,7 @@ def route(pagination: PAGINATION = self.pagination) -> List[SCHEMA]:
return route

def _get_one(self, *args: Any, **kwargs: Any) -> CALLABLE:
def route(item_id: int) -> SCHEMA:
def route(item_id: self._pk_type) -> SCHEMA:
for model in self.models:
if model.id == item_id: # type: ignore
return model
Expand All @@ -68,16 +73,26 @@ def route(item_id: int) -> SCHEMA:

def _create(self, *args: Any, **kwargs: Any) -> CALLABLE:
def route(model: self.create_schema) -> SCHEMA: # type: ignore
model, using_default_factory = create_schema_default_factory(
schema_cls=self.schema,
create_schema_instance=model,
pk_field_name=self._pk,
)
model_dict = model.dict()
model_dict["id"] = self._get_next_id()
if using_default_factory:
for _model in self.models:
if _model.id == model.id: # type: ignore
raise HTTPException(422, "Key already exists") from None
else:
model_dict["id"] = self._get_next_id()
ready_model = self.schema(**model_dict)
self.models.append(ready_model)
return ready_model

return route

def _update(self, *args: Any, **kwargs: Any) -> CALLABLE:
def route(item_id: int, model: self.update_schema) -> SCHEMA: # type: ignore
def route(item_id: self._pk_type, model: self.update_schema) -> SCHEMA: # type: ignore
for ind, model_ in enumerate(self.models):
if model_.id == item_id: # type: ignore
self.models[ind] = self.schema(
Expand All @@ -97,7 +112,7 @@ def route() -> List[SCHEMA]:
return route

def _delete_one(self, *args: Any, **kwargs: Any) -> CALLABLE:
def route(item_id: int) -> SCHEMA:
def route(item_id: self._pk_type) -> SCHEMA:
for ind, model in enumerate(self.models):
if model.id == item_id: # type: ignore
del self.models[ind]
Expand Down
10 changes: 10 additions & 0 deletions fastapi_crudrouter/core/ormar.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,7 @@

from . import CRUDGenerator, NOT_FOUND, _utils
from ._types import DEPENDENCIES, PAGINATION
from ._utils import create_schema_default_factory

try:
from ormar import Model, NoMatch
Expand All @@ -33,6 +34,7 @@ def __init__(
schema: Type[Model],
create_schema: Optional[Type[Model]] = None,
update_schema: Optional[Type[Model]] = None,
default_factory_schema: Optional[Type[Model]] = None,
prefix: Optional[str] = None,
tags: Optional[List[str]] = None,
paginate: Optional[int] = None,
Expand All @@ -48,6 +50,9 @@ def __init__(

self._pk: str = schema.Meta.pkname
self._pk_type: type = _utils.get_pk_type(schema, self._pk)
self.default_factory_schema = (
default_factory_schema if default_factory_schema else schema
)

super().__init__(
schema=schema,
Expand Down Expand Up @@ -94,6 +99,11 @@ async def route(item_id: self._pk_type) -> Model: # type: ignore

def _create(self, *args: Any, **kwargs: Any) -> CALLABLE:
async def route(model: self.create_schema) -> Model: # type: ignore
model, _ = create_schema_default_factory(
schema_cls=self.default_factory_schema,
create_schema_instance=model,
pk_field_name=self._pk,
)
model_dict = model.dict()
if self.schema.Meta.model_fields[self._pk].autoincrement:
model_dict.pop(self._pk, None)
Expand Down
7 changes: 7 additions & 0 deletions fastapi_crudrouter/core/sqlalchemy.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@

from . import CRUDGenerator, NOT_FOUND, _utils
from ._types import DEPENDENCIES, PAGINATION, PYDANTIC_SCHEMA as SCHEMA
from ._utils import create_schema_default_factory

try:
from sqlalchemy.orm import Session
Expand Down Expand Up @@ -102,6 +103,12 @@ def route(
model: self.create_schema, # type: ignore
db: Session = Depends(self.db_func),
) -> Model:
model, _ = create_schema_default_factory(
schema_cls=self.schema,
create_schema_instance=model,
pk_field_name=self._pk,
)

try:
db_model: Model = self.db_model(**model.dict())
db.add(db_model)
Expand Down
13 changes: 10 additions & 3 deletions fastapi_crudrouter/core/tortoise.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@

from . import CRUDGenerator, NOT_FOUND
from ._types import DEPENDENCIES, PAGINATION, PYDANTIC_SCHEMA as SCHEMA
from ._utils import get_pk_type, create_schema_default_factory

try:
from tortoise.models import Model
Expand Down Expand Up @@ -40,6 +41,7 @@ def __init__(

self.db_model = db_model
self._pk: str = db_model.describe()["pk_field"]["db_column"]
self._pk_type: type = get_pk_type(schema, self._pk)

super().__init__(
schema=schema,
Expand Down Expand Up @@ -68,7 +70,7 @@ async def route(pagination: PAGINATION = self.pagination) -> List[Model]:
return route

def _get_one(self, *args: Any, **kwargs: Any) -> CALLABLE:
async def route(item_id: int) -> Model:
async def route(item_id: self._pk_type) -> Model:
model = await self.db_model.filter(id=item_id).first()

if model:
Expand All @@ -80,6 +82,11 @@ async def route(item_id: int) -> Model:

def _create(self, *args: Any, **kwargs: Any) -> CALLABLE:
async def route(model: self.create_schema) -> Model: # type: ignore
model, _ = create_schema_default_factory(
schema_cls=self.schema,
create_schema_instance=model,
pk_field_name=self._pk,
)
db_model = self.db_model(**model.dict())
await db_model.save()

Expand All @@ -89,7 +96,7 @@ async def route(model: self.create_schema) -> Model: # type: ignore

def _update(self, *args: Any, **kwargs: Any) -> CALLABLE:
async def route(
item_id: int, model: self.update_schema # type: ignore
item_id: self._pk_type, model: self.update_schema # type: ignore
) -> Model:
await self.db_model.filter(id=item_id).update(
**model.dict(exclude_unset=True)
Expand All @@ -106,7 +113,7 @@ async def route() -> List[Model]:
return route

def _delete_one(self, *args: Any, **kwargs: Any) -> CALLABLE:
async def route(item_id: int) -> Model:
async def route(item_id: self._pk_type) -> Model:
model: Model = await self._get_one()(item_id)
await self.db_model.filter(id=item_id).delete()

Expand Down
14 changes: 13 additions & 1 deletion tests/__init__.py
Original file line number Diff line number Diff line change
@@ -1,9 +1,11 @@
from pydantic import BaseModel
from uuid import uuid4
from pydantic import BaseModel, Field

from .conf import config

PAGINATION_SIZE = 10
CUSTOM_TAGS = ["Tag1", "Tag2"]
POTATO_TAGS = ["Potato"]


class ORMModel(BaseModel):
Expand All @@ -24,6 +26,16 @@ class Potato(PotatoCreate, ORMModel):
pass


class DefaultFactoryPotatoCreate(BaseModel):
color: str
mass: float


class DefaultFactoryPotato(DefaultFactoryPotatoCreate, ORMModel):
id: str = Field(default_factory=lambda: str(uuid4()))
pass


class CustomPotato(PotatoCreate):
potato_id: int

Expand Down
17 changes: 17 additions & 0 deletions tests/implementations/databases_.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,9 @@
CustomPotato,
PAGINATION_SIZE,
Potato,
DefaultFactoryPotato,
PotatoType,
POTATO_TAGS,
CUSTOM_TAGS,
config,
)
Expand Down Expand Up @@ -47,6 +49,13 @@ def databases_implementation(db_uri: str):
Column("color", String),
Column("type", String),
)
defaultfactorypotatoes = Table(
"defaultfactorypotatoes",
metadata,
Column("id", String, primary_key=True),
Column("color", String),
Column("mass", Float),
)
carrots = Table(
"carrots",
metadata,
Expand Down Expand Up @@ -74,6 +83,14 @@ async def shutdown():
prefix="potato",
paginate=PAGINATION_SIZE,
),
dict(
database=database,
table=defaultfactorypotatoes,
schema=DefaultFactoryPotato,
prefix="defaultfactorypotato",
tags=POTATO_TAGS,
paginate=PAGINATION_SIZE,
),
dict(
database=database,
table=carrots,
Expand Down
16 changes: 16 additions & 0 deletions tests/implementations/gino_.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,8 @@
CarrotCreate,
CarrotUpdate,
CustomPotato,
DefaultFactoryPotato,
POTATO_TAGS,
Potato,
PotatoType,
config,
Expand Down Expand Up @@ -47,6 +49,12 @@ class PotatoModel(db.Model):
color = db.Column(db.String)
type = db.Column(db.String)

class DefaultFactoryPotatoModel(db.Model):
__tablename__ = "defaultfactorypotatoes"
id = db.Column(db.String, primary_key=True, index=True)
mass = db.Column(db.Float)
color = db.Column(db.String)

class CarrotModel(db.Model):
__tablename__ = "carrots"
id = db.Column(db.Integer, primary_key=True, index=True)
Expand All @@ -63,6 +71,14 @@ class CarrotModel(db.Model):
prefix="potato",
paginate=PAGINATION_SIZE,
),
dict(
schema=DefaultFactoryPotato,
db_model=DefaultFactoryPotatoModel,
db=db,
prefix="defaultfactorypotato",
tags=POTATO_TAGS,
paginate=PAGINATION_SIZE,
),
dict(
schema=Carrot,
db_model=CarrotModel,
Expand Down
16 changes: 14 additions & 2 deletions tests/implementations/memory.py
Original file line number Diff line number Diff line change
@@ -1,13 +1,22 @@
from fastapi import FastAPI

from fastapi_crudrouter import MemoryCRUDRouter
from tests import Potato, Carrot, CarrotUpdate, PAGINATION_SIZE, CUSTOM_TAGS
from tests import (
Potato,
DefaultFactoryPotato,
Carrot,
CarrotUpdate,
PAGINATION_SIZE,
CUSTOM_TAGS,
POTATO_TAGS,
)


def memory_implementation(**kwargs):
app = FastAPI()
router_settings = [
dict(schema=Potato, paginate=PAGINATION_SIZE),
dict(schema=DefaultFactoryPotato, paginate=PAGINATION_SIZE, tags=POTATO_TAGS),
dict(schema=Carrot, update_schema=CarrotUpdate, tags=CUSTOM_TAGS),
]

Expand All @@ -17,4 +26,7 @@ def memory_implementation(**kwargs):
if __name__ == "__main__":
import uvicorn

uvicorn.run(memory_implementation(), port=5000)
app, route_type, routes = memory_implementation()
for route in routes:
app.include_router(route_type(**route))
uvicorn.run(app, port=5000)
Loading