Skip to content

Commit 55daa45

Browse files
authored
serverless: add create_objects and fix several issues (#178)
1 parent 068a700 commit 55daa45

File tree

5 files changed

+331
-159
lines changed

5 files changed

+331
-159
lines changed

src/ansys/dynamicreporting/core/exceptions.py

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -126,3 +126,9 @@ class MultipleObjectsReturnedError(ADRException):
126126
"""Exception raised if only one object was expected, but multiple were returned."""
127127

128128
detail = "get() returned more than one object."
129+
130+
131+
class IntegrityError(ADRException):
132+
"""Exception raised if there is a constraint violation while saving an object in the database."""
133+
134+
detail = "A database integrity check failed."

src/ansys/dynamicreporting/core/serverless/adr.py

Lines changed: 51 additions & 18 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,8 @@
1+
from collections.abc import Iterable
12
import os
23
from pathlib import Path
34
import platform
5+
import shutil
46
import sys
57
from typing import Any, Optional, Type, Union
68
import uuid
@@ -136,12 +138,29 @@ def _get_install_directory(self, ansys_installation: str) -> Path:
136138
raise InvalidAnsysPath(f"Unable to detect an installation in: {','.join(dirs_to_check)}")
137139

138140
def _check_dir(self, dir_):
139-
dir_path = Path(dir_)
141+
dir_path = Path(dir_) if not isinstance(dir_, Path) else dir_
140142
if not dir_path.exists() or not dir_path.is_dir():
141143
self._logger.error(f"Invalid directory path: {dir_}")
142144
raise InvalidPath(extra_detail=dir_)
143145
return dir_path
144146

147+
def _migrate_db(self, db):
148+
try: # upgrade databases
149+
management.call_command("migrate", "--no-input", "--database", db, verbosity=0)
150+
except Exception as e:
151+
self._logger.error(f"{e}")
152+
raise DatabaseMigrationError(extra_detail=str(e))
153+
else:
154+
from django.contrib.auth.models import Group, Permission, User
155+
156+
if not User.objects.using(db).filter(is_superuser=True).exists():
157+
user = User.objects.using(db).create_superuser("nexus", "", "cei")
158+
# include the nexus group (with all permissions)
159+
nexus_group, created = Group.objects.using(db).get_or_create(name="nexus")
160+
if created:
161+
nexus_group.permissions.set(Permission.objects.using(db).all())
162+
nexus_group.user_set.add(user)
163+
145164
def setup(self, collect_static: bool = False) -> None:
146165
from django.conf import settings
147166

@@ -207,22 +226,11 @@ def setup(self, collect_static: bool = False) -> None:
207226
raise ImproperlyConfiguredError(extra_detail=str(e))
208227

209228
# migrations
210-
if self._db_directory is not None:
211-
try: # upgrades all databases
212-
management.call_command("migrate", "--no-input", verbosity=0)
213-
except Exception as e:
214-
self._logger.error(f"{e}")
215-
raise DatabaseMigrationError(extra_detail=str(e))
216-
else:
217-
from django.contrib.auth.models import Group, Permission, User
218-
219-
if not User.objects.filter(is_superuser=True).exists():
220-
user = User.objects.create_superuser("nexus", "", "cei")
221-
# include the nexus group (with all permissions)
222-
nexus_group, created = Group.objects.get_or_create(name="nexus")
223-
if created:
224-
nexus_group.permissions.set(Permission.objects.all())
225-
nexus_group.user_set.add(user)
229+
if self._databases:
230+
for db in self._databases:
231+
self._migrate_db(db)
232+
elif self._db_directory is not None:
233+
self._migrate_db("default")
226234

227235
# geometry migration
228236
try:
@@ -339,8 +347,33 @@ def query(
339347
self,
340348
query_type: Union[Session, Dataset, Type[Item], Type[Template]],
341349
query: str = "",
350+
**kwargs: Any,
342351
) -> ObjectSet:
343352
if not issubclass(query_type, (Item, Template, Session, Dataset)):
344353
self._logger.error(f"{query_type} is not valid")
345354
raise TypeError(f"{query_type} is not valid")
346-
return query_type.find(query=query)
355+
return query_type.find(query=query, **kwargs)
356+
357+
@staticmethod
358+
def create_objects(
359+
objects: Union[list, ObjectSet],
360+
**kwargs: Any,
361+
) -> int:
362+
if not isinstance(objects, Iterable):
363+
raise ADRException("objects must be an iterable")
364+
count = 0
365+
for obj in objects:
366+
if kwargs.get("using", "default") != obj.db:
367+
# required if copying across databases
368+
obj.reinit()
369+
obj.save(**kwargs)
370+
count += 1
371+
return count
372+
373+
def _is_sqlite(self, database: str) -> bool:
374+
return "sqlite" in self._databases[database]["ENGINE"]
375+
376+
def _get_db_dir(self, database: str) -> str:
377+
if self._is_sqlite(database):
378+
return self._databases[database]["NAME"]
379+
return ""

src/ansys/dynamicreporting/core/serverless/base.py

Lines changed: 131 additions & 55 deletions
Original file line numberDiff line numberDiff line change
@@ -17,13 +17,15 @@
1717
ObjectDoesNotExist,
1818
ValidationError,
1919
)
20-
from django.db import DatabaseError
20+
from django.db import DataError
2121
from django.db.models import Model, QuerySet
2222
from django.db.models.base import subclass_exception
2323
from django.db.models.manager import Manager
24+
from django.db.utils import IntegrityError as DBIntegrityError
2425

2526
from ..exceptions import (
2627
ADRException,
28+
IntegrityError,
2729
MultipleObjectsReturnedError,
2830
ObjectDoesNotExistError,
2931
ObjectNotSavedError,
@@ -42,7 +44,7 @@ def handle_field_errors(func):
4244
def wrapper(*args, **kwargs):
4345
try:
4446
return func(*args, **kwargs)
45-
except (FieldError, FieldDoesNotExist, ValidationError, DatabaseError) as e:
47+
except (FieldError, FieldDoesNotExist, ValidationError, DataError) as e:
4648
raise ADRException(extra_detail=f"One or more fields set or accessed are invalid: {e}")
4749

4850
return wrapper
@@ -97,6 +99,13 @@ def __new__(
9799
parents,
98100
namespace.get("__module__"),
99101
)
102+
add_exception_to_cls(
103+
"IntegrityError",
104+
IntegrityError,
105+
new_cls,
106+
parents,
107+
namespace.get("__module__"),
108+
)
100109
# all classes must be dataclasses
101110
new_cls = dataclass(eq=False, order=False, repr=False)(new_cls)
102111
return new_cls
@@ -121,7 +130,7 @@ def __getattribute__(cls, name):
121130

122131

123132
class BaseModel(metaclass=BaseMeta):
124-
guid: UUID = field(init=False, compare=False, kw_only=True, default_factory=uuid.uuid1)
133+
guid: UUID = field(compare=False, kw_only=True, default_factory=uuid.uuid1)
125134
tags: str = field(compare=False, kw_only=True, default="")
126135
_saved: bool = field(
127136
init=False, compare=False, default=False
@@ -209,19 +218,108 @@ def _get_field_names(cls, with_types=False, include_private=False):
209218
fields_.append((f.name, f.type) if with_types else f.name)
210219
return tuple(fields_)
211220

221+
def _get_var_field_names(self, include_private=False):
222+
fields_ = []
223+
for f in vars(self).keys():
224+
if not include_private and f.startswith("_"):
225+
continue
226+
fields_.append(f)
227+
return tuple(fields_)
228+
212229
@classmethod
213-
def _get_all_field_names(cls):
230+
def _get_prop_field_names(cls):
214231
"""Returns a list of all field names from a dataclass, including properties."""
215232
property_fields = []
216233
for name, value in inspect.getmembers(cls):
217234
if isinstance(value, property):
218235
property_fields.append(name)
219-
return tuple(property_fields) + cls._get_field_names()
236+
return tuple(property_fields)
220237

221238
@property
222-
def saved(self):
239+
def saved(self) -> bool:
223240
return self._saved
224241

242+
@property
243+
def _orm_saved(self) -> bool:
244+
return not self._orm_instance._state.adding
245+
246+
@property
247+
def _orm_db(self) -> str:
248+
return self._orm_instance._state.db
249+
250+
@property
251+
def db(self):
252+
return self._orm_db
253+
254+
def as_dict(self):
255+
out_dict = {}
256+
# use a combination of vars and fields
257+
cls_fields = set(self._get_field_names() + self._get_var_field_names())
258+
for field_ in cls_fields:
259+
if field_.startswith("_"):
260+
continue
261+
value = getattr(self, field_, None)
262+
if value is None: # skip and use defaults
263+
continue
264+
out_dict[field_] = value
265+
return out_dict
266+
267+
def _prepare_for_save(self, **kwargs):
268+
self._saved = False
269+
270+
target_db = kwargs.pop("using", "default")
271+
cls_fields = self._get_field_names() + self._get_prop_field_names()
272+
model_fields = self._get_orm_field_names(self._orm_instance)
273+
for field_ in cls_fields:
274+
if field_ not in model_fields:
275+
continue
276+
value = getattr(self, field_, None)
277+
if value is None: # skip and use defaults
278+
continue
279+
if isinstance(value, list):
280+
objs = [o._orm_instance for o in value]
281+
getattr(self._orm_instance, field_).add(*objs)
282+
else:
283+
if isinstance(value, BaseModel): # relations
284+
try:
285+
value = value._orm_instance.__class__.objects.using(target_db).get(
286+
guid=value.guid
287+
)
288+
except ObjectDoesNotExist as e:
289+
raise value.__class__.DoesNotExist(
290+
extra_detail=f"Object with guid '{value.guid}'" f" does not exist: {e}"
291+
)
292+
# for all others
293+
setattr(self._orm_instance, field_, value)
294+
295+
return self
296+
297+
def reinit(self):
298+
self._orm_instance = self.__class__._orm_model_cls()
299+
300+
@handle_field_errors
301+
def save(self, **kwargs):
302+
try:
303+
obj = self._prepare_for_save(**kwargs)
304+
obj._orm_instance.save(**kwargs)
305+
except DBIntegrityError as e:
306+
raise self.__class__.IntegrityError(
307+
extra_detail=f"Save failed for object with guid '{self.guid}': {e}"
308+
)
309+
except Exception as e:
310+
raise e
311+
else:
312+
obj._saved = True
313+
314+
def delete(self, **kwargs):
315+
if not self._saved:
316+
raise self.__class__.NotSaved(
317+
extra_detail=f"Delete failed for object with guid '{self.guid}'."
318+
)
319+
count, _ = self._orm_instance.delete(**kwargs)
320+
self._saved = False
321+
return count
322+
225323
@classmethod
226324
def from_db(cls, orm_instance, parent=None):
227325
cls_fields = dict(cls._get_field_names(with_types=True, include_private=True))
@@ -284,56 +382,21 @@ def from_db(cls, orm_instance, parent=None):
284382
obj._saved = True
285383
return obj
286384

287-
@handle_field_errors
288-
def save(self, **kwargs):
289-
self._saved = False # reset
290-
291-
cls_fields = self._get_all_field_names()
292-
model_fields = self._get_orm_field_names(self._orm_instance)
293-
for field_ in cls_fields:
294-
if field_ not in model_fields:
295-
continue
296-
value = getattr(self, field_, None)
297-
if value is None:
298-
continue
299-
if isinstance(value, list):
300-
obj_list = []
301-
for obj in value:
302-
obj_list.append(obj._orm_instance)
303-
getattr(self._orm_instance, field_).add(*obj_list)
304-
else:
305-
if isinstance(value, BaseModel): # relations
306-
try:
307-
value = value._orm_instance.__class__.objects.using(
308-
kwargs.get("using", "default")
309-
).get(guid=value.guid)
310-
except ObjectDoesNotExist:
311-
raise value.__class__.DoesNotExist
312-
# for all others
313-
setattr(self._orm_instance, field_, value)
314-
315-
self._orm_instance.save(**kwargs)
316-
self._saved = True
317-
318385
@classmethod
319386
@handle_field_errors
320387
def create(cls, **kwargs):
388+
target_db = kwargs.pop("using", "default")
321389
obj = cls(**kwargs)
322-
obj.save(force_insert=True)
390+
obj.save(force_insert=True, using=target_db)
323391
return obj
324392

325-
def delete(self, **kwargs):
326-
if not self._saved:
327-
raise self.__class__.NotSaved(extra_detail="Delete failed")
328-
count, _ = self._orm_instance.delete(**kwargs)
329-
self._saved = False
330-
return count
331-
332393
@classmethod
333394
@handle_field_errors
334395
def get(cls, **kwargs):
335396
try:
336-
orm_instance = cls._orm_model_cls.objects.get(**kwargs)
397+
orm_instance = cls._orm_model_cls.objects.using(kwargs.pop("using", "default")).get(
398+
**kwargs
399+
)
337400
except ObjectDoesNotExist:
338401
raise cls.DoesNotExist
339402
except MultipleObjectsReturned:
@@ -343,21 +406,30 @@ def get(cls, **kwargs):
343406

344407
@classmethod
345408
@handle_field_errors
346-
def filter(cls, **kwargs):
347-
qs = cls._orm_model_cls.objects.filter(**kwargs)
348-
return ObjectSet(_model=cls, _orm_model=cls._orm_model_cls, _orm_queryset=qs)
409+
def get_or_create(cls, **kwargs):
410+
try:
411+
return cls.get(**kwargs), False
412+
except cls.DoesNotExist:
413+
# Try to create an object using passed params.
414+
try:
415+
return cls.create(**kwargs), True
416+
except cls.IntegrityError:
417+
try:
418+
return cls.get(**kwargs), False
419+
except cls.DoesNotExist:
420+
pass
421+
raise
349422

350423
@classmethod
351424
@handle_field_errors
352-
def bulk_create(cls, **kwargs):
353-
objs = cls._orm_model_cls.objects.bulk_create(**kwargs)
354-
qs = cls._orm_model_cls.objects.filter(pk__in=[obj.pk for obj in objs])
425+
def filter(cls, **kwargs):
426+
qs = cls._orm_model_cls.objects.using(kwargs.pop("using", "default")).filter(**kwargs)
355427
return ObjectSet(_model=cls, _orm_model=cls._orm_model_cls, _orm_queryset=qs)
356428

357429
@classmethod
358430
@handle_field_errors
359-
def find(cls, query="", reverse=False, sort_tag="date"):
360-
qs = cls._orm_model_cls.find(query=query, reverse=reverse, sort_tag=sort_tag)
431+
def find(cls, query="", **kwargs):
432+
qs = cls._orm_model_cls.find(query=query, **kwargs)
361433
return ObjectSet(_model=cls, _orm_model=cls._orm_model_cls, _orm_queryset=qs)
362434

363435
def get_tags(self):
@@ -426,9 +498,13 @@ def saved(self):
426498
return self._saved
427499

428500
def delete(self):
501+
count = 0
502+
for obj in self._obj_set:
503+
obj.delete()
504+
count += 1
505+
self._orm_queryset.delete()
429506
self._obj_set = []
430507
self._saved = False
431-
count, _ = self._orm_queryset.delete()
432508
return count
433509

434510
def values_list(self, *fields, flat=False):

0 commit comments

Comments
 (0)