-
Notifications
You must be signed in to change notification settings - Fork 16
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
INTPYTHON-487 Polymorphic Collection Support #269
base: main
Are you sure you want to change the base?
Changes from all commits
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -1,6 +1,9 @@ | ||
from itertools import chain | ||
|
||
from django.core.exceptions import FieldError | ||
from django.db import NotSupportedError, models | ||
|
||
from .managers import EmbeddedModelManager | ||
from .managers import EmbeddedModelManager, MultiMongoManager | ||
|
||
|
||
class EmbeddedModel(models.Model): | ||
|
@@ -14,3 +17,136 @@ def delete(self, *args, **kwargs): | |
|
||
def save(self, *args, **kwargs): | ||
raise NotSupportedError("EmbeddedModels cannot be saved.") | ||
|
||
|
||
class ModelBaseOverride(models.base.ModelBase): | ||
__excluded_fieldnames = ("_t", "id") | ||
|
||
def __new__(cls, name, bases, attrs, **kwargs): | ||
"""An override to the ModelBase which inspects inherited Model | ||
definitions and passes down the field names and table reference | ||
from parent to child model. | ||
** REMAINING TODO | ||
- Handle Index Creation | ||
- Tests | ||
""" | ||
parents = [b for b in bases if isinstance(b, models.base.ModelBase)] | ||
|
||
# if no ModelBase instances found, this is the first inherited MultiModel | ||
if not parents: | ||
return super().__new__(cls, name, bases, attrs, **kwargs) | ||
|
||
# Recursively fetch all fields of a class. | ||
# Only conclude the loop when we get the MultiModel class | ||
# We cannot explicitly pass a reference to the MultiModel class | ||
# because this builds a circluar dependency | ||
holder = bases | ||
traverse = holder[0] | ||
if traverse.__name__ != "MultiModel" and hasattr(traverse, "_meta"): | ||
while traverse and traverse.__name__ != "MultiModel" and hasattr(traverse, "_meta"): | ||
traverse = traverse._meta._bases[0] if traverse._meta._bases else None | ||
holder = (traverse,) | ||
|
||
parent_fields = [] | ||
|
||
# Set up managed + default db if not set | ||
if hasattr(parents[0], "_meta") and parents[0].__name__ != "MultiModel": | ||
if not attrs.get("Meta"): | ||
|
||
class Meta: | ||
db_table = parents[0]._meta.db_table | ||
managed = False | ||
|
||
attrs["Meta"] = Meta() | ||
|
||
elif meta := attrs.get("Meta"): | ||
if not getattr(meta, "db_table", None): | ||
meta.db_table = parents[0]._meta.db_table | ||
if not getattr(meta, "managed", None): | ||
meta.managed = False | ||
parent_fields = set(parents[0]._meta.local_fields + parents[0]._meta.local_many_to_many) | ||
|
||
# The parent class will not be passed to the __new__ construction | ||
# because we will leverage Django's multi-table inheritance | ||
# which would lead to more complications on field resolution | ||
new_attrs = {**attrs} | ||
|
||
for field in parent_fields: | ||
if not models.base._has_contribute_to_class(field): | ||
if field.name in new_attrs: | ||
raise FieldError( | ||
f"Local field {field.name!r} in class {name!r} clashes with field of " | ||
f"the same name from base class {parents[0].__name__!r}." | ||
) | ||
new_attrs[field.name] = field | ||
|
||
# Construct new class without passing the parent reference, but adding | ||
# every new (derived) attribute to the django class | ||
new_cls = super().__new__(cls, name, holder, new_attrs, **kwargs) | ||
|
||
new_fields = chain( | ||
new_cls._meta.local_fields, | ||
new_cls._meta.local_many_to_many, | ||
new_cls._meta.private_fields, | ||
) | ||
field_names = {f.name for f in new_fields} | ||
|
||
for field in parent_fields: | ||
if field.primary_key or field.name in ModelBaseOverride.__excluded_fieldnames: | ||
continue | ||
if models.base._has_contribute_to_class(field): | ||
if ( | ||
field.name in field_names | ||
and field.name not in ModelBaseOverride.__excluded_fieldnames | ||
): | ||
raise FieldError( | ||
f"Local field {field.name!r} in class {name!r} clashes with field of " | ||
f"the same name from base class {parents[0].__name__!r}." | ||
) | ||
|
||
# if not hasattr(new_cls, field.name): | ||
new_cls.add_to_class(field.name, field) | ||
# Add each value as a subclass to its parent MultiModel object | ||
for _base in parents: | ||
# equivalent of if _base is MultiModel | ||
if hasattr(_base, "_subclasses"): | ||
_base._subclasses.setdefault(_base, []).append(new_cls) | ||
|
||
new_cls._meta._bases = parents | ||
new_cls._meta.parents = {} | ||
return new_cls | ||
|
||
|
||
class MultiModel(models.Model, metaclass=ModelBaseOverride): | ||
"""Manager handles tracking all inherited subclasses to be used in the MultiMongoManager query""" | ||
|
||
_subclasses = {} | ||
|
||
def __init_subclass__(cls, **kwargs): | ||
super().__init_subclass__(**kwargs) | ||
for _base in cls.__bases__: | ||
if issubclass(_base, MultiModel): | ||
MultiModel._subclasses.setdefault(_base, []).append(cls) | ||
|
||
# Get all the subclasses for my model | ||
@classmethod | ||
def subclasses(cls): | ||
stack = [cls] | ||
acc = set() | ||
while stack: | ||
node = stack.pop() | ||
stack.extend(cls._subclasses.get(node, [])) | ||
acc.add(node) | ||
return [obj.__name__ for obj in acc] | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Maybe adding from base import A
class B(A):
pass file2.py from base import A
class B(A):
pass base.py class A:
pass main.py from file1 import B
from file2 import B as B2
print(B.__name__, B2.__name__) # it prints-> B, B
print(B.__module__, B2.__module__). # it prints -> file2 file3 |
||
|
||
_t = models.CharField(max_length=255, editable=False) | ||
objects = MultiMongoManager() | ||
|
||
# Save the classname as the _t before saving | ||
def save(self, *args, **kwargs): | ||
if not self._t: | ||
self._t = self.__class__.__name__ | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. If the above comment is ok, we should also modify this one. |
||
super().save(*args, **kwargs) | ||
|
||
class Meta: | ||
abstract = True |
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -1,5 +1,6 @@ | ||
from itertools import chain | ||
|
||
from django.apps import apps | ||
from django.core.exceptions import FieldDoesNotExist | ||
from django.db import connections | ||
from django.db.models import QuerySet | ||
|
@@ -13,6 +14,13 @@ def raw_aggregate(self, pipeline, using=None): | |
return RawQuerySet(pipeline, model=self.model, using=using) | ||
|
||
|
||
class MultiMongoQuerySet(MongoQuerySet): | ||
def __iter__(self, *args, **kwargs): | ||
for obj in super().__iter__(*args, **kwargs): | ||
model_class = apps.get_model(obj._meta.app_label, obj._t) | ||
yield model_class.objects.get(pk=obj.pk) | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. 🤔 what is the flow here? is it doing N+1 queries to get the data? There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Jib mentioned that it's a temporary hack. |
||
|
||
|
||
class RawQuerySet(BaseRawQuerySet): | ||
def __init__(self, pipeline, model=None, using=None): | ||
super().__init__(pipeline, model=model, using=using) | ||
|
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
my model? our model. (don't know if it is good to post a meme here)