1717 ObjectDoesNotExist ,
1818 ValidationError ,
1919)
20- from django .db import DatabaseError
20+ from django .db import DataError
2121from django .db .models import Model , QuerySet
2222from django .db .models .base import subclass_exception
2323from django .db .models .manager import Manager
24+ from django .db .utils import IntegrityError as DBIntegrityError
2425
2526from ..exceptions import (
2627 ADRException ,
28+ IntegrityError ,
2729 MultipleObjectsReturnedError ,
2830 ObjectDoesNotExistError ,
2931 ObjectNotSavedError ,
@@ -42,7 +44,7 @@ def handle_field_errors(func):
4244 def wrapper (* args , ** kwargs ):
4345 try :
4446 return func (* args , ** kwargs )
45- except (FieldError , FieldDoesNotExist , ValidationError , DatabaseError ) as e :
47+ except (FieldError , FieldDoesNotExist , ValidationError , DataError ) as e :
4648 raise ADRException (extra_detail = f"One or more fields set or accessed are invalid: { e } " )
4749
4850 return wrapper
@@ -97,6 +99,13 @@ def __new__(
9799 parents ,
98100 namespace .get ("__module__" ),
99101 )
102+ add_exception_to_cls (
103+ "IntegrityError" ,
104+ IntegrityError ,
105+ new_cls ,
106+ parents ,
107+ namespace .get ("__module__" ),
108+ )
100109 # all classes must be dataclasses
101110 new_cls = dataclass (eq = False , order = False , repr = False )(new_cls )
102111 return new_cls
@@ -121,7 +130,7 @@ def __getattribute__(cls, name):
121130
122131
123132class BaseModel (metaclass = BaseMeta ):
124- guid : UUID = field (init = False , compare = False , kw_only = True , default_factory = uuid .uuid1 )
133+ guid : UUID = field (compare = False , kw_only = True , default_factory = uuid .uuid1 )
125134 tags : str = field (compare = False , kw_only = True , default = "" )
126135 _saved : bool = field (
127136 init = False , compare = False , default = False
@@ -209,19 +218,108 @@ def _get_field_names(cls, with_types=False, include_private=False):
209218 fields_ .append ((f .name , f .type ) if with_types else f .name )
210219 return tuple (fields_ )
211220
221+ def _get_var_field_names (self , include_private = False ):
222+ fields_ = []
223+ for f in vars (self ).keys ():
224+ if not include_private and f .startswith ("_" ):
225+ continue
226+ fields_ .append (f )
227+ return tuple (fields_ )
228+
212229 @classmethod
213- def _get_all_field_names (cls ):
230+ def _get_prop_field_names (cls ):
214231 """Returns a list of all field names from a dataclass, including properties."""
215232 property_fields = []
216233 for name , value in inspect .getmembers (cls ):
217234 if isinstance (value , property ):
218235 property_fields .append (name )
219- return tuple (property_fields ) + cls . _get_field_names ()
236+ return tuple (property_fields )
220237
221238 @property
222- def saved (self ):
239+ def saved (self ) -> bool :
223240 return self ._saved
224241
242+ @property
243+ def _orm_saved (self ) -> bool :
244+ return not self ._orm_instance ._state .adding
245+
246+ @property
247+ def _orm_db (self ) -> str :
248+ return self ._orm_instance ._state .db
249+
250+ @property
251+ def db (self ):
252+ return self ._orm_db
253+
254+ def as_dict (self ):
255+ out_dict = {}
256+ # use a combination of vars and fields
257+ cls_fields = set (self ._get_field_names () + self ._get_var_field_names ())
258+ for field_ in cls_fields :
259+ if field_ .startswith ("_" ):
260+ continue
261+ value = getattr (self , field_ , None )
262+ if value is None : # skip and use defaults
263+ continue
264+ out_dict [field_ ] = value
265+ return out_dict
266+
267+ def _prepare_for_save (self , ** kwargs ):
268+ self ._saved = False
269+
270+ target_db = kwargs .pop ("using" , "default" )
271+ cls_fields = self ._get_field_names () + self ._get_prop_field_names ()
272+ model_fields = self ._get_orm_field_names (self ._orm_instance )
273+ for field_ in cls_fields :
274+ if field_ not in model_fields :
275+ continue
276+ value = getattr (self , field_ , None )
277+ if value is None : # skip and use defaults
278+ continue
279+ if isinstance (value , list ):
280+ objs = [o ._orm_instance for o in value ]
281+ getattr (self ._orm_instance , field_ ).add (* objs )
282+ else :
283+ if isinstance (value , BaseModel ): # relations
284+ try :
285+ value = value ._orm_instance .__class__ .objects .using (target_db ).get (
286+ guid = value .guid
287+ )
288+ except ObjectDoesNotExist as e :
289+ raise value .__class__ .DoesNotExist (
290+ extra_detail = f"Object with guid '{ value .guid } '" f" does not exist: { e } "
291+ )
292+ # for all others
293+ setattr (self ._orm_instance , field_ , value )
294+
295+ return self
296+
297+ def reinit (self ):
298+ self ._orm_instance = self .__class__ ._orm_model_cls ()
299+
300+ @handle_field_errors
301+ def save (self , ** kwargs ):
302+ try :
303+ obj = self ._prepare_for_save (** kwargs )
304+ obj ._orm_instance .save (** kwargs )
305+ except DBIntegrityError as e :
306+ raise self .__class__ .IntegrityError (
307+ extra_detail = f"Save failed for object with guid '{ self .guid } ': { e } "
308+ )
309+ except Exception as e :
310+ raise e
311+ else :
312+ obj ._saved = True
313+
314+ def delete (self , ** kwargs ):
315+ if not self ._saved :
316+ raise self .__class__ .NotSaved (
317+ extra_detail = f"Delete failed for object with guid '{ self .guid } '."
318+ )
319+ count , _ = self ._orm_instance .delete (** kwargs )
320+ self ._saved = False
321+ return count
322+
225323 @classmethod
226324 def from_db (cls , orm_instance , parent = None ):
227325 cls_fields = dict (cls ._get_field_names (with_types = True , include_private = True ))
@@ -284,56 +382,21 @@ def from_db(cls, orm_instance, parent=None):
284382 obj ._saved = True
285383 return obj
286384
287- @handle_field_errors
288- def save (self , ** kwargs ):
289- self ._saved = False # reset
290-
291- cls_fields = self ._get_all_field_names ()
292- model_fields = self ._get_orm_field_names (self ._orm_instance )
293- for field_ in cls_fields :
294- if field_ not in model_fields :
295- continue
296- value = getattr (self , field_ , None )
297- if value is None :
298- continue
299- if isinstance (value , list ):
300- obj_list = []
301- for obj in value :
302- obj_list .append (obj ._orm_instance )
303- getattr (self ._orm_instance , field_ ).add (* obj_list )
304- else :
305- if isinstance (value , BaseModel ): # relations
306- try :
307- value = value ._orm_instance .__class__ .objects .using (
308- kwargs .get ("using" , "default" )
309- ).get (guid = value .guid )
310- except ObjectDoesNotExist :
311- raise value .__class__ .DoesNotExist
312- # for all others
313- setattr (self ._orm_instance , field_ , value )
314-
315- self ._orm_instance .save (** kwargs )
316- self ._saved = True
317-
318385 @classmethod
319386 @handle_field_errors
320387 def create (cls , ** kwargs ):
388+ target_db = kwargs .pop ("using" , "default" )
321389 obj = cls (** kwargs )
322- obj .save (force_insert = True )
390+ obj .save (force_insert = True , using = target_db )
323391 return obj
324392
325- def delete (self , ** kwargs ):
326- if not self ._saved :
327- raise self .__class__ .NotSaved (extra_detail = "Delete failed" )
328- count , _ = self ._orm_instance .delete (** kwargs )
329- self ._saved = False
330- return count
331-
332393 @classmethod
333394 @handle_field_errors
334395 def get (cls , ** kwargs ):
335396 try :
336- orm_instance = cls ._orm_model_cls .objects .get (** kwargs )
397+ orm_instance = cls ._orm_model_cls .objects .using (kwargs .pop ("using" , "default" )).get (
398+ ** kwargs
399+ )
337400 except ObjectDoesNotExist :
338401 raise cls .DoesNotExist
339402 except MultipleObjectsReturned :
@@ -343,21 +406,30 @@ def get(cls, **kwargs):
343406
344407 @classmethod
345408 @handle_field_errors
346- def filter (cls , ** kwargs ):
347- qs = cls ._orm_model_cls .objects .filter (** kwargs )
348- return ObjectSet (_model = cls , _orm_model = cls ._orm_model_cls , _orm_queryset = qs )
409+ def get_or_create (cls , ** kwargs ):
410+ try :
411+ return cls .get (** kwargs ), False
412+ except cls .DoesNotExist :
413+ # Try to create an object using passed params.
414+ try :
415+ return cls .create (** kwargs ), True
416+ except cls .IntegrityError :
417+ try :
418+ return cls .get (** kwargs ), False
419+ except cls .DoesNotExist :
420+ pass
421+ raise
349422
350423 @classmethod
351424 @handle_field_errors
352- def bulk_create (cls , ** kwargs ):
353- objs = cls ._orm_model_cls .objects .bulk_create (** kwargs )
354- qs = cls ._orm_model_cls .objects .filter (pk__in = [obj .pk for obj in objs ])
425+ def filter (cls , ** kwargs ):
426+ qs = cls ._orm_model_cls .objects .using (kwargs .pop ("using" , "default" )).filter (** kwargs )
355427 return ObjectSet (_model = cls , _orm_model = cls ._orm_model_cls , _orm_queryset = qs )
356428
357429 @classmethod
358430 @handle_field_errors
359- def find (cls , query = "" , reverse = False , sort_tag = "date" ):
360- qs = cls ._orm_model_cls .find (query = query , reverse = reverse , sort_tag = sort_tag )
431+ def find (cls , query = "" , ** kwargs ):
432+ qs = cls ._orm_model_cls .find (query = query , ** kwargs )
361433 return ObjectSet (_model = cls , _orm_model = cls ._orm_model_cls , _orm_queryset = qs )
362434
363435 def get_tags (self ):
@@ -426,9 +498,13 @@ def saved(self):
426498 return self ._saved
427499
428500 def delete (self ):
501+ count = 0
502+ for obj in self ._obj_set :
503+ obj .delete ()
504+ count += 1
505+ self ._orm_queryset .delete ()
429506 self ._obj_set = []
430507 self ._saved = False
431- count , _ = self ._orm_queryset .delete ()
432508 return count
433509
434510 def values_list (self , * fields , flat = False ):
0 commit comments