Skip to content

Commit

Permalink
SQLAlchemy 모델을 제네릭에서 지정하여 사용할 수 있도록 구현 (#7)
Browse files Browse the repository at this point in the history
  • Loading branch information
NEONKID authored Sep 21, 2021
1 parent c3dcec4 commit a64bfd3
Show file tree
Hide file tree
Showing 2 changed files with 65 additions and 55 deletions.
6 changes: 3 additions & 3 deletions poetry.lock

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

114 changes: 62 additions & 52 deletions pymfdata/rdb/repository.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,8 @@
from contextlib import AbstractAsyncContextManager, AbstractContextManager
from typing import Callable, final, Iterator, List, Protocol, TypeVar, Optional
from typing import Callable, final, Iterator, get_args, List, Protocol, Optional, Type, TypeVar
from sqlalchemy.ext.asyncio import AsyncSession
from sqlalchemy.future import select
from sqlalchemy.inspection import inspect
from sqlalchemy.orm import Session, Query
from sqlalchemy.sql.selectable import Select

Expand All @@ -12,33 +13,35 @@


class AsyncRepository(Protocol[_MT, _T]):
_model: _MT
_session_factory: Callable[..., AbstractAsyncContextManager]
_pk_column: str

@property
def _model(self):
return get_args(self.__orig_bases__[0])[0]

@property
def _pk_column(self) -> str:
return inspect(self._model).primary_key[0].name

async def delete_by_pk(self, pk: _T) -> bool:
item = await self.find_by_pk(pk)
if item is not None:
session: AsyncSession
async with self._session_factory() as session:
session: AsyncSession
async with self._session_factory() as session:
item = await self.find_by_pk(session, pk)
if item is not None:
await session.delete(item)
await session.commit()

return True
return False
return True

async def find_by_pk(self, pk: _T) -> Optional[_MT]:
return await self.find_by_col(**{self._pk_column: pk})
return False

@final
async def find_by_col(self, **kwargs) -> Optional[_MT]:
if not await self.is_exists(**kwargs):
return None
async def find_by_pk(self, session: AsyncSession, pk: _T) -> Optional[_MT]:
return await self.find_by_col(session, **{self._pk_column: pk})

session: AsyncSession
async with self._session_factory() as session:
item = await session.execute(self._gen_stmt_for_param(**kwargs))
return item.unique().scalars().one()
@final
async def find_by_col(self, session: AsyncSession, **kwargs) -> Optional[_MT]:
item = await session.execute(self._gen_stmt_for_param(**kwargs))
return item.unique().scalars().one_or_none()

@final
def _gen_stmt_for_param(self, **kwargs) -> Select:
Expand All @@ -61,59 +64,66 @@ async def find_all(self, **kwargs) -> List[_MT]:
async def is_exists(self, **kwargs) -> bool:
session: AsyncSession
async with self._session_factory() as session:
return await session.execute(self._gen_stmt_for_param(**kwargs).exists().select())
result = await session.execute(self._gen_stmt_for_param(**kwargs).exists().select())
return result.scalar()

@final
async def save(self, item: Base):
async def save(self, item: _MT):
session: AsyncSession
async with self._session_factory() as session:
session.add(item)
await session.commit()
await session.refresh(item)

async def update_by_pk(self, pk: _T, req: dict) -> bool:
item = await self.find_by_pk(pk)
if item is not None:
session: AsyncSession
async with self._session_factory() as session:
session: AsyncSession
async with self._session_factory() as session:
item = await self.find_by_pk(session, pk)
if item is not None:
for k, v in req.items():
if v is not None:
setattr(item, k, v)

await session.commit()
await session.refresh(item)

return True
return False
return True
return False


class SyncRepository(Protocol[_MT, _T]):
_model: _MT
_session_factory: Callable[..., AbstractContextManager]
_pk_column: str

@property
def _model(self):
return get_args(self.__orig_bases__[0])[0]

@property
def _pk_column(self) -> str:
return inspect(self._model).primary_key[0].name

@final
def count(self, **kwargs) -> int:
return self._gen_query_for_param(**kwargs).count()

def delete_by_pk(self, pk: _T) -> bool:
item = self.find_by_pk(pk)
if item is not None:
session: Session
with self._session_factory() as session:
session: Session
with self._session_factory() as session:
item = self.find_by_pk(session, pk)
if item is not None:
session.delete(item)
session.commit()

return True
return False
return True
return False

def find_by_pk(self, pk: _T) -> Optional[_MT]:
return self.find_by_col(**{self._pk_column: pk})
def find_by_pk(self, session: Session, pk: _T) -> Optional[_MT]:
return self.find_by_col(session, **{self._pk_column: pk})

@final
def find_by_col(self, **kwargs) -> Optional[_MT]:
if not self.is_exists(**kwargs):
return None

with self._session_factory() as session:
query = self._gen_query_for_param(session, **kwargs)
return query.one()
def find_by_col(self, session: Session, **kwargs) -> Optional[_MT]:
query = self._gen_query_for_param(session, **kwargs)
return query.one_or_none()

@final
def _gen_query_for_param(self, session: Session, **kwargs) -> Query:
Expand Down Expand Up @@ -144,16 +154,16 @@ def save(self, item: Base):
session.commit()

def update_by_pk(self, pk: _T, req: dict) -> bool:
item = self.find_by_pk(pk)
if item is not None:
session: Session
with self._session_factory() as session:
session: Session
with self._session_factory() as session:
item = self.find_by_pk(session, pk)
if item is not None:
for k, v in req.items():
if v is not None:
setattr(item, k, v)

await session.commit()
await session.refresh(item)
session.commit()
session.refresh(item)

return True
return False
return True
return False

0 comments on commit a64bfd3

Please sign in to comment.