Skip to content

Commit

Permalink
serverless: add create_objects and fix several issues (#178)
Browse files Browse the repository at this point in the history
  • Loading branch information
viseshrp authored Nov 12, 2024
1 parent 068a700 commit 55daa45
Show file tree
Hide file tree
Showing 5 changed files with 331 additions and 159 deletions.
6 changes: 6 additions & 0 deletions src/ansys/dynamicreporting/core/exceptions.py
Original file line number Diff line number Diff line change
Expand Up @@ -126,3 +126,9 @@ class MultipleObjectsReturnedError(ADRException):
"""Exception raised if only one object was expected, but multiple were returned."""

detail = "get() returned more than one object."


class IntegrityError(ADRException):
"""Exception raised if there is a constraint violation while saving an object in the database."""

detail = "A database integrity check failed."
69 changes: 51 additions & 18 deletions src/ansys/dynamicreporting/core/serverless/adr.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,8 @@
from collections.abc import Iterable
import os
from pathlib import Path
import platform
import shutil
import sys
from typing import Any, Optional, Type, Union
import uuid
Expand Down Expand Up @@ -136,12 +138,29 @@ def _get_install_directory(self, ansys_installation: str) -> Path:
raise InvalidAnsysPath(f"Unable to detect an installation in: {','.join(dirs_to_check)}")

def _check_dir(self, dir_):
dir_path = Path(dir_)
dir_path = Path(dir_) if not isinstance(dir_, Path) else dir_
if not dir_path.exists() or not dir_path.is_dir():
self._logger.error(f"Invalid directory path: {dir_}")
raise InvalidPath(extra_detail=dir_)
return dir_path

def _migrate_db(self, db):
try: # upgrade databases
management.call_command("migrate", "--no-input", "--database", db, verbosity=0)
except Exception as e:
self._logger.error(f"{e}")
raise DatabaseMigrationError(extra_detail=str(e))
else:
from django.contrib.auth.models import Group, Permission, User

if not User.objects.using(db).filter(is_superuser=True).exists():
user = User.objects.using(db).create_superuser("nexus", "", "cei")
# include the nexus group (with all permissions)
nexus_group, created = Group.objects.using(db).get_or_create(name="nexus")
if created:
nexus_group.permissions.set(Permission.objects.using(db).all())
nexus_group.user_set.add(user)

def setup(self, collect_static: bool = False) -> None:
from django.conf import settings

Expand Down Expand Up @@ -207,22 +226,11 @@ def setup(self, collect_static: bool = False) -> None:
raise ImproperlyConfiguredError(extra_detail=str(e))

# migrations
if self._db_directory is not None:
try: # upgrades all databases
management.call_command("migrate", "--no-input", verbosity=0)
except Exception as e:
self._logger.error(f"{e}")
raise DatabaseMigrationError(extra_detail=str(e))
else:
from django.contrib.auth.models import Group, Permission, User

if not User.objects.filter(is_superuser=True).exists():
user = User.objects.create_superuser("nexus", "", "cei")
# include the nexus group (with all permissions)
nexus_group, created = Group.objects.get_or_create(name="nexus")
if created:
nexus_group.permissions.set(Permission.objects.all())
nexus_group.user_set.add(user)
if self._databases:
for db in self._databases:
self._migrate_db(db)
elif self._db_directory is not None:
self._migrate_db("default")

# geometry migration
try:
Expand Down Expand Up @@ -339,8 +347,33 @@ def query(
self,
query_type: Union[Session, Dataset, Type[Item], Type[Template]],
query: str = "",
**kwargs: Any,
) -> ObjectSet:
if not issubclass(query_type, (Item, Template, Session, Dataset)):
self._logger.error(f"{query_type} is not valid")
raise TypeError(f"{query_type} is not valid")
return query_type.find(query=query)
return query_type.find(query=query, **kwargs)

@staticmethod
def create_objects(
objects: Union[list, ObjectSet],
**kwargs: Any,
) -> int:
if not isinstance(objects, Iterable):
raise ADRException("objects must be an iterable")
count = 0
for obj in objects:
if kwargs.get("using", "default") != obj.db:
# required if copying across databases
obj.reinit()
obj.save(**kwargs)
count += 1
return count

def _is_sqlite(self, database: str) -> bool:
return "sqlite" in self._databases[database]["ENGINE"]

def _get_db_dir(self, database: str) -> str:
if self._is_sqlite(database):
return self._databases[database]["NAME"]
return ""
186 changes: 131 additions & 55 deletions src/ansys/dynamicreporting/core/serverless/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,13 +17,15 @@
ObjectDoesNotExist,
ValidationError,
)
from django.db import DatabaseError
from django.db import DataError
from django.db.models import Model, QuerySet
from django.db.models.base import subclass_exception
from django.db.models.manager import Manager
from django.db.utils import IntegrityError as DBIntegrityError

from ..exceptions import (
ADRException,
IntegrityError,
MultipleObjectsReturnedError,
ObjectDoesNotExistError,
ObjectNotSavedError,
Expand All @@ -42,7 +44,7 @@ def handle_field_errors(func):
def wrapper(*args, **kwargs):
try:
return func(*args, **kwargs)
except (FieldError, FieldDoesNotExist, ValidationError, DatabaseError) as e:
except (FieldError, FieldDoesNotExist, ValidationError, DataError) as e:
raise ADRException(extra_detail=f"One or more fields set or accessed are invalid: {e}")

return wrapper
Expand Down Expand Up @@ -97,6 +99,13 @@ def __new__(
parents,
namespace.get("__module__"),
)
add_exception_to_cls(
"IntegrityError",
IntegrityError,
new_cls,
parents,
namespace.get("__module__"),
)
# all classes must be dataclasses
new_cls = dataclass(eq=False, order=False, repr=False)(new_cls)
return new_cls
Expand All @@ -121,7 +130,7 @@ def __getattribute__(cls, name):


class BaseModel(metaclass=BaseMeta):
guid: UUID = field(init=False, compare=False, kw_only=True, default_factory=uuid.uuid1)
guid: UUID = field(compare=False, kw_only=True, default_factory=uuid.uuid1)
tags: str = field(compare=False, kw_only=True, default="")
_saved: bool = field(
init=False, compare=False, default=False
Expand Down Expand Up @@ -209,19 +218,108 @@ def _get_field_names(cls, with_types=False, include_private=False):
fields_.append((f.name, f.type) if with_types else f.name)
return tuple(fields_)

def _get_var_field_names(self, include_private=False):
fields_ = []
for f in vars(self).keys():
if not include_private and f.startswith("_"):
continue
fields_.append(f)
return tuple(fields_)

@classmethod
def _get_all_field_names(cls):
def _get_prop_field_names(cls):
"""Returns a list of all field names from a dataclass, including properties."""
property_fields = []
for name, value in inspect.getmembers(cls):
if isinstance(value, property):
property_fields.append(name)
return tuple(property_fields) + cls._get_field_names()
return tuple(property_fields)

@property
def saved(self):
def saved(self) -> bool:
return self._saved

@property
def _orm_saved(self) -> bool:
return not self._orm_instance._state.adding

@property
def _orm_db(self) -> str:
return self._orm_instance._state.db

@property
def db(self):
return self._orm_db

def as_dict(self):
out_dict = {}
# use a combination of vars and fields
cls_fields = set(self._get_field_names() + self._get_var_field_names())
for field_ in cls_fields:
if field_.startswith("_"):
continue
value = getattr(self, field_, None)
if value is None: # skip and use defaults
continue
out_dict[field_] = value
return out_dict

def _prepare_for_save(self, **kwargs):
self._saved = False

target_db = kwargs.pop("using", "default")
cls_fields = self._get_field_names() + self._get_prop_field_names()
model_fields = self._get_orm_field_names(self._orm_instance)
for field_ in cls_fields:
if field_ not in model_fields:
continue
value = getattr(self, field_, None)
if value is None: # skip and use defaults
continue
if isinstance(value, list):
objs = [o._orm_instance for o in value]
getattr(self._orm_instance, field_).add(*objs)
else:
if isinstance(value, BaseModel): # relations
try:
value = value._orm_instance.__class__.objects.using(target_db).get(
guid=value.guid
)
except ObjectDoesNotExist as e:
raise value.__class__.DoesNotExist(
extra_detail=f"Object with guid '{value.guid}'" f" does not exist: {e}"
)
# for all others
setattr(self._orm_instance, field_, value)

return self

def reinit(self):
self._orm_instance = self.__class__._orm_model_cls()

@handle_field_errors
def save(self, **kwargs):
try:
obj = self._prepare_for_save(**kwargs)
obj._orm_instance.save(**kwargs)
except DBIntegrityError as e:
raise self.__class__.IntegrityError(
extra_detail=f"Save failed for object with guid '{self.guid}': {e}"
)
except Exception as e:
raise e
else:
obj._saved = True

def delete(self, **kwargs):
if not self._saved:
raise self.__class__.NotSaved(
extra_detail=f"Delete failed for object with guid '{self.guid}'."
)
count, _ = self._orm_instance.delete(**kwargs)
self._saved = False
return count

@classmethod
def from_db(cls, orm_instance, parent=None):
cls_fields = dict(cls._get_field_names(with_types=True, include_private=True))
Expand Down Expand Up @@ -284,56 +382,21 @@ def from_db(cls, orm_instance, parent=None):
obj._saved = True
return obj

@handle_field_errors
def save(self, **kwargs):
self._saved = False # reset

cls_fields = self._get_all_field_names()
model_fields = self._get_orm_field_names(self._orm_instance)
for field_ in cls_fields:
if field_ not in model_fields:
continue
value = getattr(self, field_, None)
if value is None:
continue
if isinstance(value, list):
obj_list = []
for obj in value:
obj_list.append(obj._orm_instance)
getattr(self._orm_instance, field_).add(*obj_list)
else:
if isinstance(value, BaseModel): # relations
try:
value = value._orm_instance.__class__.objects.using(
kwargs.get("using", "default")
).get(guid=value.guid)
except ObjectDoesNotExist:
raise value.__class__.DoesNotExist
# for all others
setattr(self._orm_instance, field_, value)

self._orm_instance.save(**kwargs)
self._saved = True

@classmethod
@handle_field_errors
def create(cls, **kwargs):
target_db = kwargs.pop("using", "default")
obj = cls(**kwargs)
obj.save(force_insert=True)
obj.save(force_insert=True, using=target_db)
return obj

def delete(self, **kwargs):
if not self._saved:
raise self.__class__.NotSaved(extra_detail="Delete failed")
count, _ = self._orm_instance.delete(**kwargs)
self._saved = False
return count

@classmethod
@handle_field_errors
def get(cls, **kwargs):
try:
orm_instance = cls._orm_model_cls.objects.get(**kwargs)
orm_instance = cls._orm_model_cls.objects.using(kwargs.pop("using", "default")).get(
**kwargs
)
except ObjectDoesNotExist:
raise cls.DoesNotExist
except MultipleObjectsReturned:
Expand All @@ -343,21 +406,30 @@ def get(cls, **kwargs):

@classmethod
@handle_field_errors
def filter(cls, **kwargs):
qs = cls._orm_model_cls.objects.filter(**kwargs)
return ObjectSet(_model=cls, _orm_model=cls._orm_model_cls, _orm_queryset=qs)
def get_or_create(cls, **kwargs):
try:
return cls.get(**kwargs), False
except cls.DoesNotExist:
# Try to create an object using passed params.
try:
return cls.create(**kwargs), True
except cls.IntegrityError:
try:
return cls.get(**kwargs), False
except cls.DoesNotExist:
pass
raise

@classmethod
@handle_field_errors
def bulk_create(cls, **kwargs):
objs = cls._orm_model_cls.objects.bulk_create(**kwargs)
qs = cls._orm_model_cls.objects.filter(pk__in=[obj.pk for obj in objs])
def filter(cls, **kwargs):
qs = cls._orm_model_cls.objects.using(kwargs.pop("using", "default")).filter(**kwargs)
return ObjectSet(_model=cls, _orm_model=cls._orm_model_cls, _orm_queryset=qs)

@classmethod
@handle_field_errors
def find(cls, query="", reverse=False, sort_tag="date"):
qs = cls._orm_model_cls.find(query=query, reverse=reverse, sort_tag=sort_tag)
def find(cls, query="", **kwargs):
qs = cls._orm_model_cls.find(query=query, **kwargs)
return ObjectSet(_model=cls, _orm_model=cls._orm_model_cls, _orm_queryset=qs)

def get_tags(self):
Expand Down Expand Up @@ -426,9 +498,13 @@ def saved(self):
return self._saved

def delete(self):
count = 0
for obj in self._obj_set:
obj.delete()
count += 1
self._orm_queryset.delete()
self._obj_set = []
self._saved = False
count, _ = self._orm_queryset.delete()
return count

def values_list(self, *fields, flat=False):
Expand Down
Loading

0 comments on commit 55daa45

Please sign in to comment.