Skip to content

Commit 5bd2849

Browse files
authored
Merge pull request #3597 from jsiirola/deepcopy-results
Resolve issues copying and pickling `SolverResults`
2 parents 5ee4159 + d6a9079 commit 5bd2849

File tree

11 files changed

+707
-553
lines changed

11 files changed

+707
-553
lines changed

pyomo/opt/results/container.py

Lines changed: 76 additions & 129 deletions
Original file line numberDiff line numberDiff line change
@@ -9,12 +9,11 @@
99
# This software is distributed under the 3-clause BSD License.
1010
# ___________________________________________________________________________
1111

12-
import copy
1312
import enum
1413
from io import StringIO
1514
from math import inf
1615

17-
from pyomo.common.collections import Bunch
16+
from pyomo.common.collections import Bunch, Sequence, Mapping
1817

1918

2019
class ScalarType(str, enum.Enum):
@@ -35,17 +34,31 @@ def __str__(self):
3534

3635

3736
default_print_options = Bunch(schema=False, ignore_time=False)
38-
3937
strict = False
4038

4139

4240
class UndefinedData(object):
41+
singleton = {}
42+
43+
def __new__(cls, name='undefined'):
44+
if name not in UndefinedData.singleton:
45+
UndefinedData.singleton[name] = super().__new__(cls)
46+
UndefinedData.singleton[name].name = name
47+
return UndefinedData.singleton[name]
48+
49+
def __deepcopy__(self, memo):
50+
# Prevent deepcopy from duplicating this object
51+
return self
52+
53+
def __reduce__(self):
54+
return self.__class__, (self.name,)
55+
4356
def __str__(self):
44-
return "<undefined>"
57+
return f"<{self.name}>"
4558

4659

47-
undefined = UndefinedData()
48-
ignore = UndefinedData()
60+
undefined = UndefinedData('undefined')
61+
ignore = UndefinedData('ignore')
4962

5063

5164
class ScalarData(object):
@@ -64,6 +77,10 @@ def __init__(
6477
self.scalar_description = scalar_description
6578
self.scalar_type = type
6679
self._required = required
80+
self._active = False
81+
82+
def __eq__(self, other):
83+
return self.__dict__ == other.__dict__
6784

6885
def get_value(self):
6986
if isinstance(self.value, enum.Enum):
@@ -109,9 +126,9 @@ def pprint(self, ostream, option, prefix="", repn=None):
109126

110127
value = self.yaml_fix(self.get_value())
111128

112-
if value is inf:
129+
if value == inf:
113130
value = '.inf'
114-
elif value is -inf:
131+
elif value == -inf:
115132
value = '-.inf'
116133

117134
if not option.schema and self.description is None and self.units is None:
@@ -149,8 +166,8 @@ def yaml_fix(self, val):
149166

150167
def load(self, repn):
151168
if type(repn) is dict:
152-
for key in repn:
153-
setattr(self, key, repn[key])
169+
for key, val in repn.items():
170+
setattr(self, key, val)
154171
else:
155172
self.value = repn
156173

@@ -167,12 +184,15 @@ def __init__(self, cls):
167184

168185
def __len__(self):
169186
if '_list' in self.__dict__:
170-
return len(self.__dict__['_list'])
187+
return len(self._list)
171188
return 0
172189

173190
def __getitem__(self, i):
174191
return self._list[i]
175192

193+
def __eq__(self, other):
194+
return self.__dict__ == other.__dict__
195+
176196
def clear(self):
177197
self._list = []
178198

@@ -183,21 +203,15 @@ def __call__(self, i=0):
183203
return self._list[i]
184204

185205
def __getattr__(self, name):
186-
try:
187-
return self.__dict__[name]
188-
except:
189-
pass
206+
if name[0] == "_":
207+
super().__getattr__(name)
190208
if len(self) == 0:
191209
self.add()
192210
return getattr(self._list[0], name)
193211

194212
def __setattr__(self, name, val):
195-
if name == "__class__":
196-
self.__class__ = val
197-
return
198213
if name[0] == "_":
199-
self.__dict__[name] = val
200-
return
214+
return super().__setattr__(name, val)
201215
if len(self) == 0:
202216
self.add()
203217
setattr(self._list[0], name, val)
@@ -239,16 +253,10 @@ def load(self, repn):
239253
item = self.add()
240254
item.load(data)
241255

242-
def __getstate__(self):
243-
return copy.copy(self.__dict__)
244-
245-
def __setstate__(self, state):
246-
self.__dict__.update(state)
247-
248256
def __str__(self):
249257
ostream = StringIO()
250258
option = default_print_options
251-
self.pprint(ostream, self._option, repn=self._repn_(self._option))
259+
self.pprint(ostream, option, repn=self._repn_(option))
252260
return ostream.getvalue()
253261

254262

@@ -259,41 +267,21 @@ def __str__(self):
259267
# first letter is capitalized.
260268
#
261269
class MapContainer(dict):
262-
def __getnewargs_ex__(self):
263-
# Pass arguments to __new__ when unpickling
264-
return ((0, 0), {})
265-
266-
def __getnewargs__(self):
267-
# Pass arguments to __new__ when unpickling
268-
return (0, 0)
269-
270-
def __new__(cls, *args, **kwargs):
271-
#
272-
# If the user provides "too many" arguments, then
273-
# pre-initialize the '_order' attribute. This pre-initializes
274-
# the class during unpickling.
275-
#
276-
_instance = super(MapContainer, cls).__new__(cls, *args, **kwargs)
277-
if len(args) > 1:
278-
super(MapContainer, _instance).__setattr__('_order', [])
279-
return _instance
280270

281271
def __init__(self, ordered=False):
282-
dict.__init__(self)
272+
super().__init__()
283273
self._active = True
284274
self._required = False
285-
self._ordered = ordered
286-
self._order = []
287275
self._option = default_print_options
288276

289-
def keys(self):
290-
return self._order
277+
def __eq__(self, other):
278+
# We need to check both our __dict__ (local attributes) and the
279+
# underlying dict data (which doesn't show up in the __dict__).
280+
# So we will use the base __eq__ in addition to checking
281+
# __dict__.
282+
return super().__eq__(other) and self.__dict__ == other.__dict__
291283

292284
def __getattr__(self, name):
293-
try:
294-
return self.__dict__[name]
295-
except:
296-
pass
297285
try:
298286
self._active = True
299287
return self[self._convert(name)]
@@ -307,12 +295,8 @@ def __getattr__(self, name):
307295
)
308296

309297
def __setattr__(self, name, val):
310-
if name == "__class__":
311-
self.__class__ = val
312-
return
313298
if name[0] == "_":
314-
self.__dict__[name] = val
315-
return
299+
return super().__setattr__(name, val)
316300
self._active = True
317301
tmp = self._convert(name)
318302
if tmp not in self:
@@ -341,12 +325,18 @@ def __setitem__(self, name, val):
341325
self._set_value(tmp, val)
342326

343327
def _set_value(self, name, val):
344-
if isinstance(val, ListContainer) or isinstance(val, MapContainer):
345-
dict.__setitem__(self, name, val)
328+
if isinstance(val, (ListContainer, MapContainer)):
329+
super().__setitem__(name, val)
346330
elif isinstance(val, ScalarData):
347-
dict.__getitem__(self, name).value = val.value
331+
data = super().__getitem__(name)
332+
data.value = val.value
333+
data._active = val._active
334+
data._required = val._required
335+
data.scalar_type = val.scalar_type
348336
else:
349-
dict.__getitem__(self, name).value = val
337+
data = super().__getitem__(name)
338+
data.value = val
339+
data._active = True
350340

351341
def __getitem__(self, name):
352342
tmp = self._convert(name)
@@ -357,25 +347,21 @@ def __getitem__(self, name):
357347
+ "' for object with type "
358348
+ str(type(self))
359349
)
360-
item = dict.__getitem__(self, tmp)
361-
if isinstance(item, ListContainer) or isinstance(item, MapContainer):
350+
item = super().__getitem__(tmp)
351+
if isinstance(item, (ListContainer, MapContainer)):
362352
return item
363353
return item.value
364354

365355
def declare(self, name, **kwds):
366356
if name in self or type(name) is int:
367357
return
368-
tmp = self._convert(name)
369-
self._order.append(tmp)
370-
if 'value' in kwds and (
371-
isinstance(kwds['value'], MapContainer)
372-
or isinstance(kwds['value'], ListContainer)
373-
):
358+
data = kwds.get('value', None)
359+
if isinstance(data, (MapContainer, ListContainer)):
374360
if 'active' in kwds:
375-
kwds['value']._active = kwds['active']
361+
data._active = kwds['active']
376362
if 'required' in kwds and kwds['required'] is True:
377-
kwds['value']._required = True
378-
dict.__setitem__(self, tmp, kwds['value'])
363+
data._required = True
364+
super().__setitem__(self._convert(name), data)
379365
else:
380366
data = ScalarData(**kwds)
381367
if 'required' in kwds and kwds['required'] is True:
@@ -387,23 +373,16 @@ def declare(self, name, **kwds):
387373
#
388374
# if 'value' in kwds:
389375
# data._default = kwds['value']
390-
dict.__setitem__(self, tmp, data)
376+
super().__setitem__(self._convert(name), data)
391377

392378
def _repn_(self, option):
393379
if not option.schema and not self._active and not self._required:
394380
return ignore
395-
if self._ordered:
396-
tmp = []
397-
for key in self._order:
398-
rep = dict.__getitem__(self, key)._repn_(option)
399-
if not rep == ignore:
400-
tmp.append({key: rep})
401-
else:
402-
tmp = {}
403-
for key in self.keys():
404-
rep = dict.__getitem__(self, key)._repn_(option)
405-
if not rep == ignore:
406-
tmp[key] = rep
381+
tmp = {}
382+
for key, val in self.items():
383+
rep = val._repn_(option)
384+
if not rep == ignore:
385+
tmp[key] = rep
407386
return tmp
408387

409388
def _convert(self, name):
@@ -417,7 +396,6 @@ def __repr__(self):
417396

418397
def __str__(self):
419398
ostream = StringIO()
420-
option = default_print_options
421399
self.pprint(ostream, self._option, repn=self._repn_(self._option))
422400
return ostream.getvalue()
423401

@@ -427,10 +405,9 @@ def pprint(self, ostream, option, from_list=False, prefix="", repn=None):
427405
else:
428406
_prefix = prefix
429407
ostream.write('\n')
430-
for key in self._order:
408+
for key, item in self.items():
431409
if not key in repn:
432410
continue
433-
item = dict.__getitem__(self, key)
434411
ostream.write(_prefix + key + ": ")
435412
_prefix = prefix
436413
if isinstance(item, ListContainer):
@@ -439,46 +416,16 @@ def pprint(self, ostream, option, from_list=False, prefix="", repn=None):
439416
item.pprint(ostream, option, prefix=_prefix + " ", repn=repn[key])
440417

441418
def load(self, repn):
442-
for key in repn:
419+
for key, val in repn.items():
443420
tmp = self._convert(key)
444421
if tmp not in self:
445422
self.declare(tmp)
446-
item = dict.__getitem__(self, tmp)
423+
item = super().__getitem__(tmp)
447424
item._active = True
448-
item.load(repn[key])
449-
450-
def __getnewargs__(self):
451-
return (False, False)
452-
453-
def __getstate__(self):
454-
return copy.copy(self.__dict__)
455-
456-
def __setstate__(self, state):
457-
self.__dict__.update(state)
458-
459-
460-
if __name__ == '__main__':
461-
d = MapContainer()
462-
d.declare('f')
463-
d.declare('g')
464-
d.declare('h')
465-
d.declare('i', value=ListContainer(UndefinedData))
466-
d.declare('j', value=ListContainer(UndefinedData), active=False)
467-
print("X")
468-
d.f = 1
469-
print("Y")
470-
print(d.f)
471-
print(d.keys())
472-
d.g = None
473-
print(d.keys())
474-
try:
475-
print(d.f, d.g, d.h)
476-
except:
477-
pass
478-
d['h'] = None
479-
print("")
480-
print("FINAL")
481-
print(d.f, d.g, d.h, d.i, d.j)
482-
print(d.i._active, d.j._active)
483-
d.j.add()
484-
print(d.i._active, d.j._active)
425+
item.load(val)
426+
427+
428+
# Register these as sequence / mapping types (so things like
429+
# assertStructuredAlmostEqual will process them correctly)
430+
Sequence.register(ListContainer)
431+
Mapping.register(MapContainer)

pyomo/opt/results/problem.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -24,7 +24,7 @@ def __str__(self):
2424

2525
class ProblemInformation(MapContainer):
2626
def __init__(self):
27-
MapContainer.__init__(self)
27+
super().__init__()
2828
self.declare('name')
2929
self.declare('lower_bound', value=float('-inf'))
3030
self.declare('upper_bound', value=float('inf'))

0 commit comments

Comments
 (0)