Skip to content
This repository has been archived by the owner on Jan 10, 2025. It is now read-only.

Commit

Permalink
statistics
Browse files Browse the repository at this point in the history
  • Loading branch information
floriankrb committed Nov 17, 2023
1 parent 0205451 commit ef6aab6
Showing 1 changed file with 52 additions and 67 deletions.
119 changes: 52 additions & 67 deletions ecml_tools/create/zarr.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,47 @@
LOG = logging.getLogger(__name__)


class StatisticsRegistry:
name = 'statistics'
#names = [ "mean", "stdev", "minimum", "maximum", "sums", "squares", "count", ]
#build_names = [ "minimum", "maximum", "sums", "squares", "count", ]

def __init__(self, dirname, zarr_registry=False, overwrite=False):
self.dirname = dirname
self.data_dirname = os.path.join(self.dirname, self.name)
self.overwrite = overwrite

def create(self):
assert not os.path.exists(self.data_dirname), self.data_dirname
os.makedirs(self.data_dirname, exist_ok=True)
if self.zarr_registry:
self.zarr_registry.add_to_history(f"{self.name}_registry_initialised", **{f"{self.name}_version"}=2)

def delete(self):
import shutil
shutil.rmtree(self.data_dirname)

def __setitem__(self, key, data):
path = self.dirname + "/" + key + ".npz"
if self.overwrite is False:
assert not os.path.exist(path), f"{path} already exists"
LOG.info(f"Writing {self.name} for {key}")
with open(path, 'wb') as f:
pickle.dump((key, data), f)
LOG.info(f"Written {self.name} data for {key} in {path}")

def read_all(self, expected_lenghts=None):
# use glob to read all pickles
files = glob.glob(self.data_dirname + "/*.npz")
LOG.info(f"Reading {self.name} data, found {len(files)} for {self.name} in {self.dirname}")
dic = {}
for f in files:
with open(f, 'rb') as f:
key, data = pickle.load(f)
assert key not in dic, f"Duplicate key {key}"
yield key, data


def add_zarr_dataset(
*,
name,
Expand Down Expand Up @@ -47,21 +88,23 @@ def add_zarr_dataset(
return a


class ZarrRegistry:
synchronizer_name = None # to be defined in subclasses

def __init__(self, path):
assert self.synchronizer_name is not None, self.synchronizer_name
class ZarrBuiltRegistry:
name_lengths = "lengths"
name_flags = "flags"
lengths = None
flags = None
z = None

def __init__(self, path, synchronizer_path=None):
import zarr

assert isinstance(path, str), path
self.zarr_path = path
self.synchronizer = zarr.ProcessSynchronizer(self._synchronizer_path)

@property
def _synchronizer_path(self):
return self.zarr_path + "-" + self.synchronizer_name + ".sync"
if synchronizer_path is None:
synchronizer_path = self.zarr_path + ".sync"
self.synchronizer_path = synchronizer_path
self.synchronizer = zarr.ProcessSynchronizer(self.synchronizer_path)

def _open_write(self):
import zarr
Expand Down Expand Up @@ -94,64 +137,6 @@ def add_to_history(self, action, **kwargs):
z.attrs["history"] = history


class ZarrStatisticsRegistry(ZarrRegistry):
names = [
"mean",
"stdev",
"minimum",
"maximum",
"sums",
"squares",
"count",
]
build_names = [
"minimum",
"maximum",
"sums",
"squares",
"count",
]
synchronizer_name = "statistics"

def __init__(self, path):
super().__init__(path)

def create(self):
z = self._open_read()
shape = z["data"].shape
shape = (shape[0], shape[1])

for name in self.build_names:
if name == "count":
self.new_dataset(name=name, shape=shape, fill_value=0, dtype=np.int64)
else:
self.new_dataset(
name=name, shape=shape, fill_value=np.nan, dtype=np.float64
)
self.add_to_history("statistics_initialised")

def __setitem__(self, key, stats):
z = self._open_write()

LOG.info(f"Writting stats for {key}")
for name in self.build_names:
LOG.info(f"Writting stats for {key} {name} {stats[name].shape}")
z["_build"][name][key] = stats[name]
LOG.info(f"Written stats for {key}")

def get_by_name(self, name):
z = self._open_read()
return z["_build"][name]


class ZarrBuiltRegistry(ZarrRegistry):
name_lengths = "lengths"
name_flags = "flags"
lengths = None
flags = None
z = None
synchronizer_name = "build"

def get_slice_for(self, i):
lengths = self.get_lengths()
assert i >= 0 and i < len(lengths)
Expand Down

0 comments on commit ef6aab6

Please sign in to comment.