-
-
Notifications
You must be signed in to change notification settings - Fork 167
/
Copy pathtortoise.py
122 lines (99 loc) · 4.16 KB
/
tortoise.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
from typing import Any, Callable, List, Type, cast, Coroutine, Optional, Union
from . import CRUDGenerator, NOT_FOUND
from ._types import DEPENDENCIES, PAGINATION, PYDANTIC_SCHEMA as SCHEMA
from ._utils import get_pk_type, create_schema_default_factory
# Tortoise ORM is an optional dependency: record whether it could be imported
# so the router can refuse to initialise without it.
try:
    from tortoise.models import Model

    tortoise_installed = True
except ImportError:
    # Placeholder so the annotations below still evaluate without Tortoise.
    Model = None  # type: ignore
    tortoise_installed = False

# Signatures of the route factories' return values: async callables that
# resolve to a single model instance or to a list of them.
CALLABLE = Callable[..., Coroutine[Any, Any, Model]]
CALLABLE_LIST = Callable[..., Coroutine[Any, Any, List[Model]]]
class TortoiseCRUDRouter(CRUDGenerator[SCHEMA]):
    """CRUD router whose route handlers read and write a Tortoise ORM model.

    Every ``_<operation>`` method is a factory: it returns the async function
    that implements that operation against ``db_model``.
    """

    def __init__(
        self,
        schema: Type[SCHEMA],
        db_model: Type[Model],
        create_schema: Optional[Type[SCHEMA]] = None,
        update_schema: Optional[Type[SCHEMA]] = None,
        prefix: Optional[str] = None,
        tags: Optional[List[str]] = None,
        paginate: Optional[int] = None,
        get_all_route: Union[bool, DEPENDENCIES] = True,
        get_one_route: Union[bool, DEPENDENCIES] = True,
        create_route: Union[bool, DEPENDENCIES] = True,
        update_route: Union[bool, DEPENDENCIES] = True,
        delete_one_route: Union[bool, DEPENDENCIES] = True,
        delete_all_route: Union[bool, DEPENDENCIES] = True,
        **kwargs: Any
    ) -> None:
        """Bind the router to ``db_model`` and initialise the base generator.

        Args:
            schema: Pydantic schema; also used to resolve the primary-key type.
            db_model: Tortoise model class the routes operate on.
            prefix: Route prefix. When falsy, the model name reported by
                ``db_model.describe()`` is used with any ``"None."`` removed.

        All remaining arguments are forwarded unchanged to
        ``CRUDGenerator.__init__``.

        Raises:
            AssertionError: If Tortoise ORM could not be imported.
        """
        assert (
            tortoise_installed
        ), "Tortoise ORM must be installed to use the TortoiseCRUDRouter."

        self.db_model = db_model

        # describe() is evaluated once and reused for both the pk and the prefix.
        description = db_model.describe()
        self._pk: str = description["pk_field"]["db_column"]
        self._pk_type: type = get_pk_type(schema, self._pk)

        super().__init__(
            schema=schema,
            create_schema=create_schema,
            update_schema=update_schema,
            prefix=prefix or description["name"].replace("None.", ""),
            tags=tags,
            paginate=paginate,
            get_all_route=get_all_route,
            get_one_route=get_one_route,
            create_route=create_route,
            update_route=update_route,
            delete_one_route=delete_one_route,
            delete_all_route=delete_all_route,
            **kwargs
        )

    def _get_all(self, *args: Any, **kwargs: Any) -> CALLABLE_LIST:
        """Return the route that lists records, honouring ``skip``/``limit``."""

        async def route(pagination: PAGINATION = self.pagination) -> List[Model]:
            skip, limit = pagination.get("skip"), pagination.get("limit")
            query = self.db_model.all().offset(cast(int, skip))
            # A falsy limit (None or 0) means "do not cap the result set".
            if limit:
                query = query.limit(limit)
            return await query

        return route

    def _get_one(self, *args: Any, **kwargs: Any) -> CALLABLE:
        """Return the route that fetches one record by primary key.

        The route raises ``NOT_FOUND`` when no record matches.
        """

        async def route(item_id: self._pk_type) -> Model:  # type: ignore
            # Filter on Tortoise's ``pk`` alias rather than a hard-coded ``id``
            # so models whose primary key has another name are still matched.
            model = await self.db_model.filter(pk=item_id).first()
            if model:
                return model
            raise NOT_FOUND

        return route

    def _create(self, *args: Any, **kwargs: Any) -> CALLABLE:
        """Return the route that builds a record from the payload and saves it."""

        async def route(model: self.create_schema) -> Model:  # type: ignore
            payload, _ = create_schema_default_factory(
                schema_cls=self.schema,
                create_schema_instance=model,
                pk_field_name=self._pk,
            )
            db_model = self.db_model(**payload.dict())
            await db_model.save()
            return db_model

        return route

    def _update(self, *args: Any, **kwargs: Any) -> CALLABLE:
        """Return the route that applies the fields set in the payload.

        The route returns the record as re-read after the update, and raises
        ``NOT_FOUND`` when no record matches ``item_id``.
        """

        async def route(
            item_id: self._pk_type, model: self.update_schema  # type: ignore
        ) -> Model:
            changes = model.dict(exclude_unset=True)
            # Nothing was set in the payload, so there is nothing to write;
            # skip the UPDATE and just return the current record.
            if changes:
                await self.db_model.filter(pk=item_id).update(**changes)
            return await self._get_one()(item_id)

        return route

    def _delete_all(self, *args: Any, **kwargs: Any) -> CALLABLE_LIST:
        """Return the route that deletes every record and lists what remains."""

        async def route() -> List[Model]:
            await self.db_model.all().delete()
            return await self._get_all()(pagination={"skip": 0, "limit": None})

        return route

    def _delete_one(self, *args: Any, **kwargs: Any) -> CALLABLE:
        """Return the route that deletes one record and returns it.

        The record is fetched first, so ``NOT_FOUND`` is raised before any
        delete is attempted when ``item_id`` does not match.
        """

        async def route(item_id: self._pk_type) -> Model:  # type: ignore
            model: Model = await self._get_one()(item_id)
            await self.db_model.filter(pk=item_id).delete()
            return model

        return route