Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
37 changes: 23 additions & 14 deletions pymc/backends/zarr.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,14 +15,11 @@
from typing import Any

import arviz as az
import numcodecs
import numpy as np
import xarray as xr
import zarr

from arviz.data.base import make_attrs
from arviz.data.inference_data import WARMUP_TAG
from numcodecs.abc import Codec
from pytensor.tensor.variable import TensorVariable

import pymc
Expand All @@ -44,11 +41,23 @@
from pymc.util import UNSET, _UnsetType, get_default_varnames, is_transformed_name

try:
import numcodecs
import zarr

from numcodecs.abc import Codec
from zarr import Group
from zarr.storage import BaseStore, default_compressor
from zarr.sync import Synchronizer

_zarr_available = True
except ImportError:
from typing import TYPE_CHECKING, TypeVar

if not TYPE_CHECKING:
Codec = TypeVar("Codec")
Group = TypeVar("Group")
BaseStore = TypeVar("BaseStore")
Synchronizer = TypeVar("Synchronizer")
_zarr_available = False


Expand Down Expand Up @@ -243,7 +252,7 @@ def flush(self):

def get_initial_fill_value_and_codec(
dtype: Any,
) -> tuple[FILL_VALUE_TYPE, np.typing.DTypeLike, numcodecs.abc.Codec | None]:
) -> tuple[FILL_VALUE_TYPE, np.typing.DTypeLike, Codec | None]:
_dtype = np.dtype(dtype)
fill_value: FILL_VALUE_TYPE = None
codec = None
Expand Down Expand Up @@ -366,27 +375,27 @@ def groups(self) -> list[str]:
return [str(group_name) for group_name, _ in self.root.groups()]

@property
def posterior(self) -> zarr.Group:
def posterior(self) -> Group:
return self.root.posterior

@property
def unconstrained_posterior(self) -> zarr.Group:
def unconstrained_posterior(self) -> Group:
return self.root.unconstrained_posterior

@property
def sample_stats(self) -> zarr.Group:
def sample_stats(self) -> Group:
return self.root.sample_stats

@property
def constant_data(self) -> zarr.Group:
def constant_data(self) -> Group:
return self.root.constant_data

@property
def observed_data(self) -> zarr.Group:
def observed_data(self) -> Group:
return self.root.observed_data

@property
def _sampling_state(self) -> zarr.Group:
def _sampling_state(self) -> Group:
return self.root._sampling_state

def init_trace(
Expand Down Expand Up @@ -646,12 +655,12 @@ def init_sampling_state_group(self, tune: int, chains: int):

def init_group_with_empty(
self,
group: zarr.Group,
group: Group,
var_dtype_and_shape: dict[str, tuple[StatDtype, StatShape]],
chains: int,
draws: int,
extra_var_attrs: dict | None = None,
) -> zarr.Group:
) -> Group:
group_coords: dict[str, Any] = {"chain": range(chains), "draw": range(draws)}
for name, (_dtype, shape) in var_dtype_and_shape.items():
fill_value, dtype, object_codec = get_initial_fill_value_and_codec(_dtype)
Expand Down Expand Up @@ -689,8 +698,8 @@ def init_group_with_empty(
array.attrs.update({"_ARRAY_DIMENSIONS": [dim]})
return group

def create_group(self, name: str, data_dict: dict[str, np.ndarray]) -> zarr.Group | None:
group: zarr.Group | None = None
def create_group(self, name: str, data_dict: dict[str, np.ndarray]) -> Group | None:
group: Group | None = None
if data_dict:
group_coords = {}
group = self.root.create_group(name=name, overwrite=True)
Expand Down
6 changes: 5 additions & 1 deletion pymc/sampling/mcmc.py
Original file line number Diff line number Diff line change
Expand Up @@ -41,7 +41,6 @@
from rich.theme import Theme
from threadpoolctl import threadpool_limits
from typing_extensions import Protocol
from zarr.storage import MemoryStore

import pymc as pm

Expand Down Expand Up @@ -80,6 +79,11 @@
)
from pymc.vartypes import discrete_types

try:
from zarr.storage import MemoryStore
except ImportError:
MemoryStore = type("MemoryStore", (), {})

sys.setrecursionlimit(10000)

__all__ = [
Expand Down