9
9
# This software is distributed under the 3-clause BSD License.
10
10
# ___________________________________________________________________________
11
11
12
- import copy
13
12
import enum
14
13
from io import StringIO
15
14
from math import inf
16
15
17
- from pyomo .common .collections import Bunch
16
+ from pyomo .common .collections import Bunch , Sequence , Mapping
18
17
19
18
20
19
class ScalarType (str , enum .Enum ):
@@ -35,17 +34,31 @@ def __str__(self):
35
34
36
35
37
36
default_print_options = Bunch (schema = False , ignore_time = False )
38
-
39
37
strict = False
40
38
41
39
42
40
class UndefinedData (object ):
41
+ singleton = {}
42
+
43
+ def __new__ (cls , name = 'undefined' ):
44
+ if name not in UndefinedData .singleton :
45
+ UndefinedData .singleton [name ] = super ().__new__ (cls )
46
+ UndefinedData .singleton [name ].name = name
47
+ return UndefinedData .singleton [name ]
48
+
49
+ def __deepcopy__ (self , memo ):
50
+ # Prevent deepcopy from duplicating this object
51
+ return self
52
+
53
+ def __reduce__ (self ):
54
+ return self .__class__ , (self .name ,)
55
+
43
56
def __str__ (self ):
44
- return "<undefined >"
57
+ return f"< { self . name } >"
45
58
46
59
47
- undefined = UndefinedData ()
48
- ignore = UndefinedData ()
60
+ undefined = UndefinedData ('undefined' )
61
+ ignore = UndefinedData ('ignore' )
49
62
50
63
51
64
class ScalarData (object ):
@@ -64,6 +77,10 @@ def __init__(
64
77
self .scalar_description = scalar_description
65
78
self .scalar_type = type
66
79
self ._required = required
80
+ self ._active = False
81
+
82
+ def __eq__ (self , other ):
83
+ return self .__dict__ == other .__dict__
67
84
68
85
def get_value (self ):
69
86
if isinstance (self .value , enum .Enum ):
@@ -109,9 +126,9 @@ def pprint(self, ostream, option, prefix="", repn=None):
109
126
110
127
value = self .yaml_fix (self .get_value ())
111
128
112
- if value is inf :
129
+ if value == inf :
113
130
value = '.inf'
114
- elif value is - inf :
131
+ elif value == - inf :
115
132
value = '-.inf'
116
133
117
134
if not option .schema and self .description is None and self .units is None :
@@ -149,8 +166,8 @@ def yaml_fix(self, val):
149
166
150
167
def load (self , repn ):
151
168
if type (repn ) is dict :
152
- for key in repn :
153
- setattr (self , key , repn [ key ] )
169
+ for key , val in repn . items () :
170
+ setattr (self , key , val )
154
171
else :
155
172
self .value = repn
156
173
@@ -167,12 +184,15 @@ def __init__(self, cls):
167
184
168
185
def __len__ (self ):
169
186
if '_list' in self .__dict__ :
170
- return len (self .__dict__ [ ' _list' ] )
187
+ return len (self ._list )
171
188
return 0
172
189
173
190
def __getitem__ (self , i ):
174
191
return self ._list [i ]
175
192
193
+ def __eq__ (self , other ):
194
+ return self .__dict__ == other .__dict__
195
+
176
196
def clear (self ):
177
197
self ._list = []
178
198
@@ -183,21 +203,15 @@ def __call__(self, i=0):
183
203
return self ._list [i ]
184
204
185
205
def __getattr__ (self , name ):
186
- try :
187
- return self .__dict__ [name ]
188
- except :
189
- pass
206
+ if name [0 ] == "_" :
207
+ super ().__getattr__ (name )
190
208
if len (self ) == 0 :
191
209
self .add ()
192
210
return getattr (self ._list [0 ], name )
193
211
194
212
def __setattr__ (self , name , val ):
195
- if name == "__class__" :
196
- self .__class__ = val
197
- return
198
213
if name [0 ] == "_" :
199
- self .__dict__ [name ] = val
200
- return
214
+ return super ().__setattr__ (name , val )
201
215
if len (self ) == 0 :
202
216
self .add ()
203
217
setattr (self ._list [0 ], name , val )
@@ -239,16 +253,10 @@ def load(self, repn):
239
253
item = self .add ()
240
254
item .load (data )
241
255
242
- def __getstate__ (self ):
243
- return copy .copy (self .__dict__ )
244
-
245
- def __setstate__ (self , state ):
246
- self .__dict__ .update (state )
247
-
248
256
def __str__ (self ):
249
257
ostream = StringIO ()
250
258
option = default_print_options
251
- self .pprint (ostream , self . _option , repn = self ._repn_ (self . _option ))
259
+ self .pprint (ostream , option , repn = self ._repn_ (option ))
252
260
return ostream .getvalue ()
253
261
254
262
@@ -259,41 +267,21 @@ def __str__(self):
259
267
# first letter is capitalized.
260
268
#
261
269
class MapContainer (dict ):
262
- def __getnewargs_ex__ (self ):
263
- # Pass arguments to __new__ when unpickling
264
- return ((0 , 0 ), {})
265
-
266
- def __getnewargs__ (self ):
267
- # Pass arguments to __new__ when unpickling
268
- return (0 , 0 )
269
-
270
- def __new__ (cls , * args , ** kwargs ):
271
- #
272
- # If the user provides "too many" arguments, then
273
- # pre-initialize the '_order' attribute. This pre-initializes
274
- # the class during unpickling.
275
- #
276
- _instance = super (MapContainer , cls ).__new__ (cls , * args , ** kwargs )
277
- if len (args ) > 1 :
278
- super (MapContainer , _instance ).__setattr__ ('_order' , [])
279
- return _instance
280
270
281
271
def __init__ (self , ordered = False ):
282
- dict .__init__ (self )
272
+ super () .__init__ ()
283
273
self ._active = True
284
274
self ._required = False
285
- self ._ordered = ordered
286
- self ._order = []
287
275
self ._option = default_print_options
288
276
289
- def keys (self ):
290
- return self ._order
277
+ def __eq__ (self , other ):
278
+ # We need to check both our __dict__ (local attributes) and the
279
+ # underlying dict data (which doesn't show up in the __dict__).
280
+ # So we will use the base __eq__ in addition to checking
281
+ # __dict__.
282
+ return super ().__eq__ (other ) and self .__dict__ == other .__dict__
291
283
292
284
def __getattr__ (self , name ):
293
- try :
294
- return self .__dict__ [name ]
295
- except :
296
- pass
297
285
try :
298
286
self ._active = True
299
287
return self [self ._convert (name )]
@@ -307,12 +295,8 @@ def __getattr__(self, name):
307
295
)
308
296
309
297
def __setattr__ (self , name , val ):
310
- if name == "__class__" :
311
- self .__class__ = val
312
- return
313
298
if name [0 ] == "_" :
314
- self .__dict__ [name ] = val
315
- return
299
+ return super ().__setattr__ (name , val )
316
300
self ._active = True
317
301
tmp = self ._convert (name )
318
302
if tmp not in self :
@@ -341,12 +325,18 @@ def __setitem__(self, name, val):
341
325
self ._set_value (tmp , val )
342
326
343
327
def _set_value (self , name , val ):
344
- if isinstance (val , ListContainer ) or isinstance ( val , MapContainer ):
345
- dict .__setitem__ (self , name , val )
328
+ if isinstance (val , ( ListContainer , MapContainer ) ):
329
+ super () .__setitem__ (name , val )
346
330
elif isinstance (val , ScalarData ):
347
- dict .__getitem__ (self , name ).value = val .value
331
+ data = super ().__getitem__ (name )
332
+ data .value = val .value
333
+ data ._active = val ._active
334
+ data ._required = val ._required
335
+ data .scalar_type = val .scalar_type
348
336
else :
349
- dict .__getitem__ (self , name ).value = val
337
+ data = super ().__getitem__ (name )
338
+ data .value = val
339
+ data ._active = True
350
340
351
341
def __getitem__ (self , name ):
352
342
tmp = self ._convert (name )
@@ -357,25 +347,21 @@ def __getitem__(self, name):
357
347
+ "' for object with type "
358
348
+ str (type (self ))
359
349
)
360
- item = dict .__getitem__ (self , tmp )
361
- if isinstance (item , ListContainer ) or isinstance ( item , MapContainer ):
350
+ item = super () .__getitem__ (tmp )
351
+ if isinstance (item , ( ListContainer , MapContainer ) ):
362
352
return item
363
353
return item .value
364
354
365
355
def declare (self , name , ** kwds ):
366
356
if name in self or type (name ) is int :
367
357
return
368
- tmp = self ._convert (name )
369
- self ._order .append (tmp )
370
- if 'value' in kwds and (
371
- isinstance (kwds ['value' ], MapContainer )
372
- or isinstance (kwds ['value' ], ListContainer )
373
- ):
358
+ data = kwds .get ('value' , None )
359
+ if isinstance (data , (MapContainer , ListContainer )):
374
360
if 'active' in kwds :
375
- kwds [ 'value' ] ._active = kwds ['active' ]
361
+ data ._active = kwds ['active' ]
376
362
if 'required' in kwds and kwds ['required' ] is True :
377
- kwds [ 'value' ] ._required = True
378
- dict .__setitem__ (self , tmp , kwds [ 'value' ] )
363
+ data ._required = True
364
+ super () .__setitem__ (self . _convert ( name ), data )
379
365
else :
380
366
data = ScalarData (** kwds )
381
367
if 'required' in kwds and kwds ['required' ] is True :
@@ -387,23 +373,16 @@ def declare(self, name, **kwds):
387
373
#
388
374
# if 'value' in kwds:
389
375
# data._default = kwds['value']
390
- dict .__setitem__ (self , tmp , data )
376
+ super () .__setitem__ (self . _convert ( name ) , data )
391
377
392
378
def _repn_ (self , option ):
393
379
if not option .schema and not self ._active and not self ._required :
394
380
return ignore
395
- if self ._ordered :
396
- tmp = []
397
- for key in self ._order :
398
- rep = dict .__getitem__ (self , key )._repn_ (option )
399
- if not rep == ignore :
400
- tmp .append ({key : rep })
401
- else :
402
- tmp = {}
403
- for key in self .keys ():
404
- rep = dict .__getitem__ (self , key )._repn_ (option )
405
- if not rep == ignore :
406
- tmp [key ] = rep
381
+ tmp = {}
382
+ for key , val in self .items ():
383
+ rep = val ._repn_ (option )
384
+ if not rep == ignore :
385
+ tmp [key ] = rep
407
386
return tmp
408
387
409
388
def _convert (self , name ):
@@ -417,7 +396,6 @@ def __repr__(self):
417
396
418
397
def __str__ (self ):
419
398
ostream = StringIO ()
420
- option = default_print_options
421
399
self .pprint (ostream , self ._option , repn = self ._repn_ (self ._option ))
422
400
return ostream .getvalue ()
423
401
@@ -427,10 +405,9 @@ def pprint(self, ostream, option, from_list=False, prefix="", repn=None):
427
405
else :
428
406
_prefix = prefix
429
407
ostream .write ('\n ' )
430
- for key in self ._order :
408
+ for key , item in self .items () :
431
409
if not key in repn :
432
410
continue
433
- item = dict .__getitem__ (self , key )
434
411
ostream .write (_prefix + key + ": " )
435
412
_prefix = prefix
436
413
if isinstance (item , ListContainer ):
@@ -439,46 +416,16 @@ def pprint(self, ostream, option, from_list=False, prefix="", repn=None):
439
416
item .pprint (ostream , option , prefix = _prefix + " " , repn = repn [key ])
440
417
441
418
def load (self , repn ):
442
- for key in repn :
419
+ for key , val in repn . items () :
443
420
tmp = self ._convert (key )
444
421
if tmp not in self :
445
422
self .declare (tmp )
446
- item = dict .__getitem__ (self , tmp )
423
+ item = super () .__getitem__ (tmp )
447
424
item ._active = True
448
- item .load (repn [key ])
449
-
450
- def __getnewargs__ (self ):
451
- return (False , False )
452
-
453
- def __getstate__ (self ):
454
- return copy .copy (self .__dict__ )
455
-
456
- def __setstate__ (self , state ):
457
- self .__dict__ .update (state )
458
-
459
-
460
- if __name__ == '__main__' :
461
- d = MapContainer ()
462
- d .declare ('f' )
463
- d .declare ('g' )
464
- d .declare ('h' )
465
- d .declare ('i' , value = ListContainer (UndefinedData ))
466
- d .declare ('j' , value = ListContainer (UndefinedData ), active = False )
467
- print ("X" )
468
- d .f = 1
469
- print ("Y" )
470
- print (d .f )
471
- print (d .keys ())
472
- d .g = None
473
- print (d .keys ())
474
- try :
475
- print (d .f , d .g , d .h )
476
- except :
477
- pass
478
- d ['h' ] = None
479
- print ("" )
480
- print ("FINAL" )
481
- print (d .f , d .g , d .h , d .i , d .j )
482
- print (d .i ._active , d .j ._active )
483
- d .j .add ()
484
- print (d .i ._active , d .j ._active )
425
+ item .load (val )
426
+
427
+
428
+ # Register these as sequence / mapping types (so things like
429
+ # assertStructuredAlmostEqual will process them correctly)
430
+ Sequence .register (ListContainer )
431
+ Mapping .register (MapContainer )
0 commit comments