17
17
ObjectDoesNotExist ,
18
18
ValidationError ,
19
19
)
20
- from django .db import DatabaseError
20
+ from django .db import DataError
21
21
from django .db .models import Model , QuerySet
22
22
from django .db .models .base import subclass_exception
23
23
from django .db .models .manager import Manager
24
+ from django .db .utils import IntegrityError as DBIntegrityError
24
25
25
26
from ..exceptions import (
26
27
ADRException ,
28
+ IntegrityError ,
27
29
MultipleObjectsReturnedError ,
28
30
ObjectDoesNotExistError ,
29
31
ObjectNotSavedError ,
@@ -42,7 +44,7 @@ def handle_field_errors(func):
42
44
def wrapper (* args , ** kwargs ):
43
45
try :
44
46
return func (* args , ** kwargs )
45
- except (FieldError , FieldDoesNotExist , ValidationError , DatabaseError ) as e :
47
+ except (FieldError , FieldDoesNotExist , ValidationError , DataError ) as e :
46
48
raise ADRException (extra_detail = f"One or more fields set or accessed are invalid: { e } " )
47
49
48
50
return wrapper
@@ -97,6 +99,13 @@ def __new__(
97
99
parents ,
98
100
namespace .get ("__module__" ),
99
101
)
102
+ add_exception_to_cls (
103
+ "IntegrityError" ,
104
+ IntegrityError ,
105
+ new_cls ,
106
+ parents ,
107
+ namespace .get ("__module__" ),
108
+ )
100
109
# all classes must be dataclasses
101
110
new_cls = dataclass (eq = False , order = False , repr = False )(new_cls )
102
111
return new_cls
@@ -121,7 +130,7 @@ def __getattribute__(cls, name):
121
130
122
131
123
132
class BaseModel (metaclass = BaseMeta ):
124
- guid : UUID = field (init = False , compare = False , kw_only = True , default_factory = uuid .uuid1 )
133
+ guid : UUID = field (compare = False , kw_only = True , default_factory = uuid .uuid1 )
125
134
tags : str = field (compare = False , kw_only = True , default = "" )
126
135
_saved : bool = field (
127
136
init = False , compare = False , default = False
@@ -209,19 +218,108 @@ def _get_field_names(cls, with_types=False, include_private=False):
209
218
fields_ .append ((f .name , f .type ) if with_types else f .name )
210
219
return tuple (fields_ )
211
220
221
+ def _get_var_field_names (self , include_private = False ):
222
+ fields_ = []
223
+ for f in vars (self ).keys ():
224
+ if not include_private and f .startswith ("_" ):
225
+ continue
226
+ fields_ .append (f )
227
+ return tuple (fields_ )
228
+
212
229
@classmethod
213
- def _get_all_field_names (cls ):
230
+ def _get_prop_field_names (cls ):
214
231
"""Returns a list of all field names from a dataclass, including properties."""
215
232
property_fields = []
216
233
for name , value in inspect .getmembers (cls ):
217
234
if isinstance (value , property ):
218
235
property_fields .append (name )
219
- return tuple (property_fields ) + cls . _get_field_names ()
236
+ return tuple (property_fields )
220
237
221
238
@property
222
- def saved (self ):
239
+ def saved (self ) -> bool :
223
240
return self ._saved
224
241
242
+ @property
243
+ def _orm_saved (self ) -> bool :
244
+ return not self ._orm_instance ._state .adding
245
+
246
+ @property
247
+ def _orm_db (self ) -> str :
248
+ return self ._orm_instance ._state .db
249
+
250
+ @property
251
+ def db (self ):
252
+ return self ._orm_db
253
+
254
+ def as_dict (self ):
255
+ out_dict = {}
256
+ # use a combination of vars and fields
257
+ cls_fields = set (self ._get_field_names () + self ._get_var_field_names ())
258
+ for field_ in cls_fields :
259
+ if field_ .startswith ("_" ):
260
+ continue
261
+ value = getattr (self , field_ , None )
262
+ if value is None : # skip and use defaults
263
+ continue
264
+ out_dict [field_ ] = value
265
+ return out_dict
266
+
267
+ def _prepare_for_save (self , ** kwargs ):
268
+ self ._saved = False
269
+
270
+ target_db = kwargs .pop ("using" , "default" )
271
+ cls_fields = self ._get_field_names () + self ._get_prop_field_names ()
272
+ model_fields = self ._get_orm_field_names (self ._orm_instance )
273
+ for field_ in cls_fields :
274
+ if field_ not in model_fields :
275
+ continue
276
+ value = getattr (self , field_ , None )
277
+ if value is None : # skip and use defaults
278
+ continue
279
+ if isinstance (value , list ):
280
+ objs = [o ._orm_instance for o in value ]
281
+ getattr (self ._orm_instance , field_ ).add (* objs )
282
+ else :
283
+ if isinstance (value , BaseModel ): # relations
284
+ try :
285
+ value = value ._orm_instance .__class__ .objects .using (target_db ).get (
286
+ guid = value .guid
287
+ )
288
+ except ObjectDoesNotExist as e :
289
+ raise value .__class__ .DoesNotExist (
290
+ extra_detail = f"Object with guid '{ value .guid } '" f" does not exist: { e } "
291
+ )
292
+ # for all others
293
+ setattr (self ._orm_instance , field_ , value )
294
+
295
+ return self
296
+
297
+ def reinit (self ):
298
+ self ._orm_instance = self .__class__ ._orm_model_cls ()
299
+
300
+ @handle_field_errors
301
+ def save (self , ** kwargs ):
302
+ try :
303
+ obj = self ._prepare_for_save (** kwargs )
304
+ obj ._orm_instance .save (** kwargs )
305
+ except DBIntegrityError as e :
306
+ raise self .__class__ .IntegrityError (
307
+ extra_detail = f"Save failed for object with guid '{ self .guid } ': { e } "
308
+ )
309
+ except Exception as e :
310
+ raise e
311
+ else :
312
+ obj ._saved = True
313
+
314
+ def delete (self , ** kwargs ):
315
+ if not self ._saved :
316
+ raise self .__class__ .NotSaved (
317
+ extra_detail = f"Delete failed for object with guid '{ self .guid } '."
318
+ )
319
+ count , _ = self ._orm_instance .delete (** kwargs )
320
+ self ._saved = False
321
+ return count
322
+
225
323
@classmethod
226
324
def from_db (cls , orm_instance , parent = None ):
227
325
cls_fields = dict (cls ._get_field_names (with_types = True , include_private = True ))
@@ -284,56 +382,21 @@ def from_db(cls, orm_instance, parent=None):
284
382
obj ._saved = True
285
383
return obj
286
384
287
- @handle_field_errors
288
- def save (self , ** kwargs ):
289
- self ._saved = False # reset
290
-
291
- cls_fields = self ._get_all_field_names ()
292
- model_fields = self ._get_orm_field_names (self ._orm_instance )
293
- for field_ in cls_fields :
294
- if field_ not in model_fields :
295
- continue
296
- value = getattr (self , field_ , None )
297
- if value is None :
298
- continue
299
- if isinstance (value , list ):
300
- obj_list = []
301
- for obj in value :
302
- obj_list .append (obj ._orm_instance )
303
- getattr (self ._orm_instance , field_ ).add (* obj_list )
304
- else :
305
- if isinstance (value , BaseModel ): # relations
306
- try :
307
- value = value ._orm_instance .__class__ .objects .using (
308
- kwargs .get ("using" , "default" )
309
- ).get (guid = value .guid )
310
- except ObjectDoesNotExist :
311
- raise value .__class__ .DoesNotExist
312
- # for all others
313
- setattr (self ._orm_instance , field_ , value )
314
-
315
- self ._orm_instance .save (** kwargs )
316
- self ._saved = True
317
-
318
385
@classmethod
319
386
@handle_field_errors
320
387
def create (cls , ** kwargs ):
388
+ target_db = kwargs .pop ("using" , "default" )
321
389
obj = cls (** kwargs )
322
- obj .save (force_insert = True )
390
+ obj .save (force_insert = True , using = target_db )
323
391
return obj
324
392
325
- def delete (self , ** kwargs ):
326
- if not self ._saved :
327
- raise self .__class__ .NotSaved (extra_detail = "Delete failed" )
328
- count , _ = self ._orm_instance .delete (** kwargs )
329
- self ._saved = False
330
- return count
331
-
332
393
@classmethod
333
394
@handle_field_errors
334
395
def get (cls , ** kwargs ):
335
396
try :
336
- orm_instance = cls ._orm_model_cls .objects .get (** kwargs )
397
+ orm_instance = cls ._orm_model_cls .objects .using (kwargs .pop ("using" , "default" )).get (
398
+ ** kwargs
399
+ )
337
400
except ObjectDoesNotExist :
338
401
raise cls .DoesNotExist
339
402
except MultipleObjectsReturned :
@@ -343,21 +406,30 @@ def get(cls, **kwargs):
343
406
344
407
@classmethod
345
408
@handle_field_errors
346
- def filter (cls , ** kwargs ):
347
- qs = cls ._orm_model_cls .objects .filter (** kwargs )
348
- return ObjectSet (_model = cls , _orm_model = cls ._orm_model_cls , _orm_queryset = qs )
409
+ def get_or_create (cls , ** kwargs ):
410
+ try :
411
+ return cls .get (** kwargs ), False
412
+ except cls .DoesNotExist :
413
+ # Try to create an object using passed params.
414
+ try :
415
+ return cls .create (** kwargs ), True
416
+ except cls .IntegrityError :
417
+ try :
418
+ return cls .get (** kwargs ), False
419
+ except cls .DoesNotExist :
420
+ pass
421
+ raise
349
422
350
423
@classmethod
351
424
@handle_field_errors
352
- def bulk_create (cls , ** kwargs ):
353
- objs = cls ._orm_model_cls .objects .bulk_create (** kwargs )
354
- qs = cls ._orm_model_cls .objects .filter (pk__in = [obj .pk for obj in objs ])
425
+ def filter (cls , ** kwargs ):
426
+ qs = cls ._orm_model_cls .objects .using (kwargs .pop ("using" , "default" )).filter (** kwargs )
355
427
return ObjectSet (_model = cls , _orm_model = cls ._orm_model_cls , _orm_queryset = qs )
356
428
357
429
@classmethod
358
430
@handle_field_errors
359
- def find (cls , query = "" , reverse = False , sort_tag = "date" ):
360
- qs = cls ._orm_model_cls .find (query = query , reverse = reverse , sort_tag = sort_tag )
431
+ def find (cls , query = "" , ** kwargs ):
432
+ qs = cls ._orm_model_cls .find (query = query , ** kwargs )
361
433
return ObjectSet (_model = cls , _orm_model = cls ._orm_model_cls , _orm_queryset = qs )
362
434
363
435
def get_tags (self ):
@@ -426,9 +498,13 @@ def saved(self):
426
498
return self ._saved
427
499
428
500
def delete (self ):
501
+ count = 0
502
+ for obj in self ._obj_set :
503
+ obj .delete ()
504
+ count += 1
505
+ self ._orm_queryset .delete ()
429
506
self ._obj_set = []
430
507
self ._saved = False
431
- count , _ = self ._orm_queryset .delete ()
432
508
return count
433
509
434
510
def values_list (self , * fields , flat = False ):
0 commit comments