Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

INTPYTHON-487 Polymorphic Collection Support #269

Open
wants to merge 1 commit into
base: main
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
7 changes: 6 additions & 1 deletion django_mongodb_backend/managers.py
Original file line number Diff line number Diff line change
@@ -1,13 +1,18 @@
from django.db import NotSupportedError
from django.db.models.manager import BaseManager

from .queryset import MongoQuerySet
from .queryset import MongoQuerySet, MultiMongoQuerySet


class MongoManager(BaseManager.from_queryset(MongoQuerySet)):
pass


class MultiMongoManager(BaseManager.from_queryset(MultiMongoQuerySet)):
def get_queryset(self):
return super().get_queryset().filter(_t__in=self.model.subclasses())


class EmbeddedModelManager(BaseManager):
"""
Prevent all queryset operations on embedded models since they don't have
Expand Down
138 changes: 137 additions & 1 deletion django_mongodb_backend/models.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,9 @@
from itertools import chain

from django.core.exceptions import FieldError
from django.db import NotSupportedError, models

from .managers import EmbeddedModelManager
from .managers import EmbeddedModelManager, MultiMongoManager


class EmbeddedModel(models.Model):
Expand All @@ -14,3 +17,136 @@ def delete(self, *args, **kwargs):

def save(self, *args, **kwargs):
raise NotSupportedError("EmbeddedModels cannot be saved.")


class ModelBaseOverride(models.base.ModelBase):
__excluded_fieldnames = ("_t", "id")

def __new__(cls, name, bases, attrs, **kwargs):
"""An override to the ModelBase which inspects inherited Model
definitions and passes down the field names and table reference
from parent to child model.
** REMAINING TODO
- Handle Index Creation
- Tests
"""
parents = [b for b in bases if isinstance(b, models.base.ModelBase)]

# if no ModelBase instances found, this is the first inherited MultiModel
if not parents:
return super().__new__(cls, name, bases, attrs, **kwargs)

# Recursively fetch all fields of a class.
# Only conclude the loop when we get the MultiModel class
# We cannot explicitly pass a reference to the MultiModel class
# because this builds a circluar dependency
holder = bases
traverse = holder[0]
if traverse.__name__ != "MultiModel" and hasattr(traverse, "_meta"):
while traverse and traverse.__name__ != "MultiModel" and hasattr(traverse, "_meta"):
traverse = traverse._meta._bases[0] if traverse._meta._bases else None
holder = (traverse,)

parent_fields = []

# Set up managed + default db if not set
if hasattr(parents[0], "_meta") and parents[0].__name__ != "MultiModel":
if not attrs.get("Meta"):

class Meta:
db_table = parents[0]._meta.db_table
managed = False

attrs["Meta"] = Meta()

elif meta := attrs.get("Meta"):
if not getattr(meta, "db_table", None):
meta.db_table = parents[0]._meta.db_table
if not getattr(meta, "managed", None):
meta.managed = False
parent_fields = set(parents[0]._meta.local_fields + parents[0]._meta.local_many_to_many)

# The parent class will not be passed to the __new__ construction
# because we will leverage Django's multi-table inheritance
# which would lead to more complications on field resolution
new_attrs = {**attrs}

for field in parent_fields:
if not models.base._has_contribute_to_class(field):
if field.name in new_attrs:
raise FieldError(
f"Local field {field.name!r} in class {name!r} clashes with field of "
f"the same name from base class {parents[0].__name__!r}."
)
new_attrs[field.name] = field

# Construct new class without passing the parent reference, but adding
# every new (derived) attribute to the django class
new_cls = super().__new__(cls, name, holder, new_attrs, **kwargs)

new_fields = chain(
new_cls._meta.local_fields,
new_cls._meta.local_many_to_many,
new_cls._meta.private_fields,
)
field_names = {f.name for f in new_fields}

for field in parent_fields:
if field.primary_key or field.name in ModelBaseOverride.__excluded_fieldnames:
continue
if models.base._has_contribute_to_class(field):
if (
field.name in field_names
and field.name not in ModelBaseOverride.__excluded_fieldnames
):
raise FieldError(
f"Local field {field.name!r} in class {name!r} clashes with field of "
f"the same name from base class {parents[0].__name__!r}."
)

# if not hasattr(new_cls, field.name):
new_cls.add_to_class(field.name, field)
# Add each value as a subclass to its parent MultiModel object
for _base in parents:
# equivalent of if _base is MultiModel
if hasattr(_base, "_subclasses"):
_base._subclasses.setdefault(_base, []).append(new_cls)

new_cls._meta._bases = parents
new_cls._meta.parents = {}
return new_cls


class MultiModel(models.Model, metaclass=ModelBaseOverride):
"""Manager handles tracking all inherited subclasses to be used in the MultiMongoManager query"""

_subclasses = {}

def __init_subclass__(cls, **kwargs):
super().__init_subclass__(**kwargs)
for _base in cls.__bases__:
if issubclass(_base, MultiModel):
MultiModel._subclasses.setdefault(_base, []).append(cls)

# Get all the subclasses for my model
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

my model? our model. (don't know if it is good to post a meme here)

@classmethod
def subclasses(cls):
stack = [cls]
acc = set()
while stack:
node = stack.pop()
stack.extend(cls._subclasses.get(node, []))
acc.add(node)
return [obj.__name__ for obj in acc]
Copy link
Collaborator

@WaVEV WaVEV Mar 20, 2025

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Maybe adding obj.__module__ and then __name__ could potentially fix a collision. two modules could have the same name in two different files.
Lets say:
file1.py:

from base import A
class B(A):
    pass

file2.py

from base import A
class B(A):
    pass

base.py

class A:
    pass

main.py

from file1 import B
from file2 import B as B2
print(B.__name__, B2.__name__)  # it prints-> B, B
print(B.__module__, B2.__module__). # it prints -> file2 file3


_t = models.CharField(max_length=255, editable=False)
objects = MultiMongoManager()

# Save the classname as the _t before saving
def save(self, *args, **kwargs):
if not self._t:
self._t = self.__class__.__name__
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

If the above comment is ok, we should also modify this one.

super().save(*args, **kwargs)

class Meta:
abstract = True
8 changes: 8 additions & 0 deletions django_mongodb_backend/queryset.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
from itertools import chain

from django.apps import apps
from django.core.exceptions import FieldDoesNotExist
from django.db import connections
from django.db.models import QuerySet
Expand All @@ -13,6 +14,13 @@ def raw_aggregate(self, pipeline, using=None):
return RawQuerySet(pipeline, model=self.model, using=using)


class MultiMongoQuerySet(MongoQuerySet):
def __iter__(self, *args, **kwargs):
for obj in super().__iter__(*args, **kwargs):
model_class = apps.get_model(obj._meta.app_label, obj._t)
yield model_class.objects.get(pk=obj.pk)
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

🤔 what is the flow here? is it doing N+1 queries to get the data?

Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Jib mentioned that it's a temporary hack.



class RawQuerySet(BaseRawQuerySet):
def __init__(self, pipeline, model=None, using=None):
super().__init__(pipeline, model=model, using=using)
Expand Down
Loading