Skip to content

Commit

Permalink
Add __setstate__() to bootstrap old checkpoints
Browse files Browse the repository at this point in the history
  • Loading branch information
oxtopus committed Apr 2, 2015
1 parent f696ca3 commit bfe0e5a
Show file tree
Hide file tree
Showing 3 changed files with 12 additions and 5 deletions.
1 change: 0 additions & 1 deletion MANIFEST.in
Original file line number Diff line number Diff line change
Expand Up @@ -13,4 +13,3 @@ recursive-include external/linux64 swig *.a

recursive-include nupic/datafiles *.csv *.txt
recursive-include *.capnp

14 changes: 11 additions & 3 deletions nupic/encoders/random_distributed_scalar.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,8 +22,8 @@
import math
import numbers
import pprint
import sys

import capnp
import numpy

from nupic.data import SENTINEL_VALUE_FOR_MISSING_DATA
Expand Down Expand Up @@ -160,6 +160,16 @@ def __init__(self, resolution, w=21, n=400, name=None, offset=None,
self.dump()


def __setstate__(self, state):
self.__dict__.update(state)

# Initialize self.random as an instance of NupicRandom derived from the
# previous numpy random state
randomState = state["random"]
if isinstance(randomState, numpy.random.mtrand.RandomState):
self.random = NupicRandom(randomState.randint(sys.maxint))


def _seed(self, seed=-1):
"""
Initialize the random seed
Expand Down Expand Up @@ -480,5 +490,3 @@ def write(self, proto):
proto.maxIndex = self.maxIndex
proto.bucketMap = [{"key": key, "value": value.tolist()}
for key, value in self.bucketMap.items()]


Original file line number Diff line number Diff line change
Expand Up @@ -438,7 +438,7 @@ def testEncodeInvalidInputType(self):

def testCapNProtoSerialization(self):
original = RandomDistributedScalarEncoder(name="enc", resolution=1.0, w=23,
n=500, offset = 0.0)
n=500, offset=0.0)

originalValue = original.encode(1)

Expand Down

0 comments on commit bfe0e5a

Please sign in to comment.