Skip to content

Fix conditional import of zarr #7645

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 1 commit into from
Jan 14, 2025
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
37 changes: 23 additions & 14 deletions pymc/backends/zarr.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,14 +15,11 @@
from typing import Any

import arviz as az
import numcodecs
import numpy as np
import xarray as xr
import zarr

from arviz.data.base import make_attrs
from arviz.data.inference_data import WARMUP_TAG
from numcodecs.abc import Codec
from pytensor.tensor.variable import TensorVariable

import pymc
Expand All @@ -44,11 +41,23 @@
from pymc.util import UNSET, _UnsetType, get_default_varnames, is_transformed_name

try:
import numcodecs
import zarr

from numcodecs.abc import Codec
from zarr import Group
from zarr.storage import BaseStore, default_compressor
from zarr.sync import Synchronizer

_zarr_available = True
except ImportError:
from typing import TYPE_CHECKING, TypeVar

Check warning on line 54 in pymc/backends/zarr.py

View check run for this annotation

Codecov / codecov/patch

pymc/backends/zarr.py#L54

Added line #L54 was not covered by tests

if not TYPE_CHECKING:
Codec = TypeVar("Codec")
Group = TypeVar("Group")
BaseStore = TypeVar("BaseStore")
Synchronizer = TypeVar("Synchronizer")

Check warning on line 60 in pymc/backends/zarr.py

View check run for this annotation

Codecov / codecov/patch

pymc/backends/zarr.py#L56-L60

Added lines #L56 - L60 were not covered by tests
_zarr_available = False


Expand Down Expand Up @@ -243,7 +252,7 @@

def get_initial_fill_value_and_codec(
dtype: Any,
) -> tuple[FILL_VALUE_TYPE, np.typing.DTypeLike, numcodecs.abc.Codec | None]:
) -> tuple[FILL_VALUE_TYPE, np.typing.DTypeLike, Codec | None]:
_dtype = np.dtype(dtype)
fill_value: FILL_VALUE_TYPE = None
codec = None
Expand Down Expand Up @@ -366,27 +375,27 @@
return [str(group_name) for group_name, _ in self.root.groups()]

@property
def posterior(self) -> zarr.Group:
def posterior(self) -> Group:
return self.root.posterior

@property
def unconstrained_posterior(self) -> zarr.Group:
def unconstrained_posterior(self) -> Group:
return self.root.unconstrained_posterior

@property
def sample_stats(self) -> zarr.Group:
def sample_stats(self) -> Group:
return self.root.sample_stats

@property
def constant_data(self) -> zarr.Group:
def constant_data(self) -> Group:
return self.root.constant_data

@property
def observed_data(self) -> zarr.Group:
def observed_data(self) -> Group:
return self.root.observed_data

@property
def _sampling_state(self) -> zarr.Group:
def _sampling_state(self) -> Group:
return self.root._sampling_state

def init_trace(
Expand Down Expand Up @@ -646,12 +655,12 @@

def init_group_with_empty(
self,
group: zarr.Group,
group: Group,
var_dtype_and_shape: dict[str, tuple[StatDtype, StatShape]],
chains: int,
draws: int,
extra_var_attrs: dict | None = None,
) -> zarr.Group:
) -> Group:
group_coords: dict[str, Any] = {"chain": range(chains), "draw": range(draws)}
for name, (_dtype, shape) in var_dtype_and_shape.items():
fill_value, dtype, object_codec = get_initial_fill_value_and_codec(_dtype)
Expand Down Expand Up @@ -689,8 +698,8 @@
array.attrs.update({"_ARRAY_DIMENSIONS": [dim]})
return group

def create_group(self, name: str, data_dict: dict[str, np.ndarray]) -> zarr.Group | None:
group: zarr.Group | None = None
def create_group(self, name: str, data_dict: dict[str, np.ndarray]) -> Group | None:
group: Group | None = None
if data_dict:
group_coords = {}
group = self.root.create_group(name=name, overwrite=True)
Expand Down
6 changes: 5 additions & 1 deletion pymc/sampling/mcmc.py
Original file line number Diff line number Diff line change
Expand Up @@ -41,7 +41,6 @@
from rich.theme import Theme
from threadpoolctl import threadpool_limits
from typing_extensions import Protocol
from zarr.storage import MemoryStore

import pymc as pm

Expand Down Expand Up @@ -80,6 +79,11 @@
)
from pymc.vartypes import discrete_types

try:
from zarr.storage import MemoryStore
except ImportError:
MemoryStore = type("MemoryStore", (), {})

Check warning on line 85 in pymc/sampling/mcmc.py

View check run for this annotation

Codecov / codecov/patch

pymc/sampling/mcmc.py#L84-L85

Added lines #L84 - L85 were not covered by tests

sys.setrecursionlimit(10000)

__all__ = [
Expand Down
Loading