Skip to content

Commit

Permalink
Add tests to existing and refactored codebase
Browse files Browse the repository at this point in the history
  • Loading branch information
tarsil committed Sep 20, 2023
1 parent 61372f9 commit d01fe99
Show file tree
Hide file tree
Showing 30 changed files with 1,327 additions and 132 deletions.
3 changes: 3 additions & 0 deletions mongoz/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -25,6 +25,7 @@
Time,
)
from .core.db.querysets.base import QuerySet
from .core.db.querysets.expressions import Expression, SortExpression
from .core.db.querysets.operators import Q

__all__ = [
Expand All @@ -42,6 +43,7 @@
"Embed",
"EmailField",
"EmbeddedDocument",
"Expression",
"Index",
"IndexType",
"Integer",
Expand All @@ -52,6 +54,7 @@
"Q",
"QuerySet",
"Registry",
"SortExpression",
"String",
"Time",
"UUID",
Expand Down
4 changes: 3 additions & 1 deletion mongoz/core/db/documents/base.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
from typing import ClassVar, Dict, List, Type, TypeVar, Union
from typing import ClassVar, Dict, List, Mapping, Type, TypeVar, Union

import bson
import pydantic
Expand All @@ -8,6 +8,7 @@

from mongoz.core.db.documents._internal import DescriptiveMeta
from mongoz.core.db.documents.metaclasses import BaseModelMeta, MetaInfo
from mongoz.core.db.fields.base import BaseField
from mongoz.core.db.fields.core import ObjectId
from mongoz.core.db.querysets.base import QuerySet
from mongoz.core.db.querysets.expressions import Expression
Expand Down Expand Up @@ -65,4 +66,5 @@ def __str__(self) -> str:


class MongozBaseModel(BaseMongoz):
__mongoz_fields__: ClassVar[Mapping[str, Type["BaseField"]]]
id: Union[ObjectId, None] = pydantic.Field(alias="_id")
20 changes: 18 additions & 2 deletions mongoz/core/db/documents/document.py
Original file line number Diff line number Diff line change
@@ -1,11 +1,16 @@
from typing import ClassVar, List, Sequence, Type, Union
from typing import ClassVar, List, Mapping, Sequence, Type, TypeVar, Union

import bson
from bson.errors import InvalidId
from pydantic import BaseModel

from mongoz.core.db.documents.base import MongozBaseModel
from mongoz.core.db.documents.metaclasses import EmbeddedModelMetaClass
from mongoz.core.db.fields.base import BaseField
from mongoz.exceptions import InvalidKeyError

T = TypeVar("T", bound="Document")


class Document(MongozBaseModel):
"""
Expand Down Expand Up @@ -33,7 +38,7 @@ async def create_many(
if not all(isinstance(model, cls) for model in models):
raise TypeError(f"All models must be of type {cls.__name__}")

data = {model.model_dump(exclude={"id"}) for model in models}
data = (model.model_dump(exclude={"id"}) for model in models)
results = await cls.meta.collection._collection.insert_many(data)
for model, inserted_id in zip(models, results.inserted_ids, strict=True):
model.id = inserted_id
Expand Down Expand Up @@ -97,6 +102,16 @@ async def save(self: Type["Document"]) -> Type["Document"]:
setattr(self, k, v)
return self

@classmethod
async def get_document_by_id(cls: Type[T], id: Union[str, bson.ObjectId]) -> T:
if isinstance(id, str):
try:
id = bson.ObjectId(id)
except InvalidId as e:
raise InvalidKeyError(f'"{id}" is not a valid ObjectId') from e

return await cls.query({"_id": id}).get()

def __repr__(self) -> str:
return str(self)

Expand All @@ -109,4 +124,5 @@ class EmbeddedDocument(BaseModel, metaclass=EmbeddedModelMetaClass):
Graphical representation of an Embedded document.
"""

__mongoz_fields__: ClassVar[Mapping[str, Type["BaseField"]]]
__embedded__: ClassVar[bool] = True
14 changes: 6 additions & 8 deletions mongoz/core/db/documents/metaclasses.py
Original file line number Diff line number Diff line change
Expand Up @@ -236,9 +236,6 @@ def __search_for_fields(base: Type, attrs: Any) -> None:
if "id" in new_class.model_fields:
new_class.model_fields["id"].default = None

# # Update the model_fields are updated to the latest
# new_class.model_fields = model_fields

# Abstract classes do not allow multiple managers. This make sure it is enforced.
if meta.abstract:
managers = [k for k, v in attrs.items() if isinstance(v, Manager)]
Expand Down Expand Up @@ -313,6 +310,7 @@ def __search_for_fields(base: Type, attrs: Any) -> None:
new_field = MongozField(pydantic_field=field, model_class=field.annotation)
mongoz_fields[field_name] = new_field

new_class.Meta = meta
new_class.__mongoz_fields__ = mongoz_fields
return new_class

Expand All @@ -329,14 +327,14 @@ class EmbeddedModelMetaClass(ModelMetaclass):
def __new__(cls, name: str, bases: Tuple[Type, ...], attrs: Any) -> Any:
attrs, model_fields = extract_field_annotations_and_defaults(attrs)
cls.__mongoz_fields__ = model_fields
cls = super().__new__(cls, name, bases, attrs)
new_class = super().__new__(cls, name, bases, attrs)

mongoz_fields: Dict[str, MongozField] = {}
for field_name, field in cls.model_fields.items():
for field_name, field in new_class.model_fields.items():
if not field.alias:
field.alias = field_name
new_field = MongozField(pydantic_field=field, model_class=cls)
new_field = MongozField(pydantic_field=field, model_class=new_class)
mongoz_fields[field_name] = new_field

cls.__mongoz_fields__ = mongoz_fields
return cls
new_class.__mongoz_fields__ = mongoz_fields
return new_class
70 changes: 0 additions & 70 deletions mongoz/core/db/documents/model_proxy.py

This file was deleted.

30 changes: 0 additions & 30 deletions mongoz/core/db/documents/row.py

This file was deleted.

12 changes: 6 additions & 6 deletions mongoz/core/db/fields/core.py
Original file line number Diff line number Diff line change
Expand Up @@ -37,7 +37,7 @@
mongoz_setattr = object.__setattr__

if TYPE_CHECKING:
from mongoz.core.db.documents.document import Document
from mongoz.core.db.documents.document import EmbeddedDocument


CLASS_DEFAULTS = ["cls", "__class__", "kwargs"]
Expand Down Expand Up @@ -401,14 +401,14 @@ class Array(FieldFactory, list):

def __new__( # type: ignore
cls,
document: type,
type_of: type,
**kwargs: Any,
) -> BaseField:
kwargs = {
**kwargs,
**{k: v for k, v in locals().items() if k not in CLASS_DEFAULTS},
}
kwargs["list_type"] = document
kwargs["list_type"] = type_of
return super().__new__(cls, **kwargs)


Expand All @@ -432,7 +432,7 @@ class Embed(FieldFactory):

def __new__( # type: ignore
cls,
document: Type["Document"],
document: Type["EmbeddedDocument"],
**kwargs: Any,
) -> BaseField:
kwargs = {
Expand All @@ -444,10 +444,10 @@ def __new__( # type: ignore

@classmethod
def validate_field(cls, **kwargs: Any) -> None:
from mongoz.core.db.documents.document import Document, EmbeddedDocument
from mongoz.core.db.documents.document import EmbeddedDocument

document = kwargs.get("document")
if not issubclass(document, (Document, EmbeddedDocument)):
if not issubclass(document, EmbeddedDocument):
raise FieldDefinitionError(
"'document' must be of type mongoz.Document or mongoz.EmbeddedDocument"
)
5 changes: 5 additions & 0 deletions mongoz/core/db/querysets/__init__.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,5 @@
from .base import QuerySet
from .expressions import Expression, SortExpression
from .operators import Q

__all__ = ["Expression", "Q", "QuerySet", "SortExpression"]
44 changes: 33 additions & 11 deletions mongoz/core/db/querysets/base.py
Original file line number Diff line number Diff line change
@@ -1,9 +1,11 @@
from typing import TYPE_CHECKING, Any, AsyncGenerator, Dict, Generic, List, Type, TypeVar, Union

import pydantic

from mongoz.core.db.datastructures import Order
from mongoz.core.db.fields.base import BaseField
from mongoz.core.db.fields.base import MongozField
from mongoz.core.db.querysets.expressions import Expression, SortExpression
from mongoz.exceptions import MultipleObjectsReturned, ObjectNotFound
from mongoz.exceptions import DocumentNotFound, MultipleDumentsReturned
from mongoz.protocols.queryset import QuerySetProtocol

if TYPE_CHECKING:
Expand Down Expand Up @@ -80,22 +82,20 @@ async def last(self) -> Union[T, None]:
"""
Returns the last document of a matching criteria.
"""

objects = await self.all()
if not objects:
return None
return objects[-1]

async def get(self) -> T:
"""Gets the document of a matching criteria."""
objects = await self.limit(2).all()
if len(objects) == 0:
raise ObjectNotFound()
raise DocumentNotFound()
elif len(objects) == 2:
raise MultipleObjectsReturned()
raise MultipleDumentsReturned()
return objects[0]

async def get_or_create(self, defaults: Union[Dict[str, Any], None]) -> T:
async def get_or_create(self, defaults: Union[Dict[str, Any], None] = None) -> T:
if not defaults:
defaults = {}

Expand Down Expand Up @@ -135,22 +135,44 @@ def sort(self, key: Any, direction: Union[Order, None] = None) -> "QuerySet[T]":
for key_dir in key:
sort_expression = SortExpression(*key_dir)
self._sort.append(sort_expression)
elif isinstance(key, (str, BaseField)):
elif isinstance(key, (str, MongozField)):
sort_expression = SortExpression(key, direction)
self._sort.append(sort_expression)
else:
self._sort.append(key)
return self

def query(self, *args: Union[bool, Dict, Expression]) -> "QuerySet[T]":
"""Filter query criteria."""

for arg in args:
assert isinstance(arg, (dict, Expression)), "Invalid argument to Query"
if isinstance(arg, dict):
query_expressions = Expression.unpack(arg)
self._filter.extend(query_expressions)
else:
self._filter.append(arg)

return self

async def update(self, **kwargs: Any) -> List[T]:
field_definitions = {
name: (annotations, ...)
for name, annotations in self.model_class.__annotations__.items()
if name in kwargs
}

if field_definitions:
pydantic_model: Type[pydantic.BaseModel] = pydantic.create_model(
__model_name=self.model_class.__name__,
__config__=self.model_class.model_config,
**field_definitions,
)
model = pydantic_model.model_validate(kwargs)
values = model.model_dump()

filter_query = Expression.compile_many(self._filter)
await self._collection.update_many(filter_query, {"$set": values})

_filter = [expression for expression in self._filter if expression.key not in values]
_filter.extend([Expression(key, "$eq", value) for key, value in values.items()])

self._filter = _filter
return await self.all()
4 changes: 2 additions & 2 deletions mongoz/core/db/querysets/operators.py
Original file line number Diff line number Diff line change
Expand Up @@ -37,13 +37,13 @@ def or_(cls, *args: Union[bool, Expression]) -> Expression:

@classmethod
def contains(cls, key: Any, value: Any) -> Expression:
if key.annotation is str:
if key.pydantic_field.annotation is str:
return Expression(key=key, operator="$regex", value=value)
return Expression(key=key, operator="$eq", value=value)

@classmethod
def pattern(cls, key: Any, value: Union[str, re.Pattern]) -> Expression:
if key.annotation is str:
if key.pydantic_field.annotation is str:
expression = value.pattern if isinstance(value, re.Pattern) else value
return Expression(key=key, operator="$regex", value=expression)
name = key if isinstance(key, str) else key._name
Expand Down
Loading

0 comments on commit d01fe99

Please sign in to comment.