From db1584ec9613bda3743e98b530953a294618c06b Mon Sep 17 00:00:00 2001 From: takkyu2 Date: Sun, 7 Dec 2025 15:14:17 -0800 Subject: [PATCH 1/3] fix some type errors --- src/treequest/algos/ab_mcts_m/_ab_mcts_m_imports.py | 8 ++++---- src/treequest/algos/ab_mcts_m/algo.py | 2 +- src/treequest/algos/ab_mcts_m/numpyro_utils.py | 2 +- src/treequest/vis/renderers/html.py | 6 +++--- 4 files changed, 9 insertions(+), 9 deletions(-) diff --git a/src/treequest/algos/ab_mcts_m/_ab_mcts_m_imports.py b/src/treequest/algos/ab_mcts_m/_ab_mcts_m_imports.py index c8588f5..1ce6443 100644 --- a/src/treequest/algos/ab_mcts_m/_ab_mcts_m_imports.py +++ b/src/treequest/algos/ab_mcts_m/_ab_mcts_m_imports.py @@ -1,14 +1,14 @@ from treequest.imports import try_import with try_import() as _import: - import jax + import jax # type: ignore[import-not-found] from packaging.version import Version # TODO: Remove this hotfix after numpyro fixes incompatibility with jax>=0.7 # https://github.com/pyro-ppl/numpyro/issues/2051 if Version(jax.__version__) >= Version("0.7.0"): - import jax.experimental.pjit as _pjit - from jax.extend.core.primitives import jit_p + import jax.experimental.pjit as _pjit # type: ignore[import-not-found] + from jax.extend.core.primitives import jit_p # type: ignore[import-not-found] _pjit.pjit_p = jit_p # type: ignore import numpy as np @@ -16,7 +16,7 @@ import pandas as pd # type: ignore import pymc as pm # type: ignore from pymc.sampling.jax import sample_numpyro_nuts # type: ignore - from xarray import DataArray + from xarray import DataArray # type: ignore[import-not-found] __all__ = [ "jax", diff --git a/src/treequest/algos/ab_mcts_m/algo.py b/src/treequest/algos/ab_mcts_m/algo.py index d54d2dd..9bb8520 100644 --- a/src/treequest/algos/ab_mcts_m/algo.py +++ b/src/treequest/algos/ab_mcts_m/algo.py @@ -19,7 +19,7 @@ StateT = TypeVar("StateT") -_WORKER_ALGO = None +_WORKER_ALGO: Optional["ABMCTSM"] = None def _worker_init_abmctsm(config: dict, per_worker_cpu_devices: int): diff --git a/src/treequest/algos/ab_mcts_m/numpyro_utils.py b/src/treequest/algos/ab_mcts_m/numpyro_utils.py index 7b05c33..4c0d6b0 100644 --- a/src/treequest/algos/ab_mcts_m/numpyro_utils.py +++ b/src/treequest/algos/ab_mcts_m/numpyro_utils.py @@ -1,7 +1,7 @@ def initialize_numpyro(num_cpu_devices: int = 4): # For 4 parallel chains import os - import numpyro # type: ignore[import-untyped] + import numpyro # type: ignore[import-not-found] # To avoid file lock error: https://github.com/pymc-devs/pymc/issues/6818 os.environ["PYTENSOR_FLAGS"] = ( diff --git a/src/treequest/vis/renderers/html.py b/src/treequest/vis/renderers/html.py index 6489978..c6a9f39 100644 --- a/src/treequest/vis/renderers/html.py +++ b/src/treequest/vis/renderers/html.py @@ -4,14 +4,14 @@ from typing import Callable, Dict, List, Optional, Union from treequest.vis.errors import DependencyNotFoundError, RenderError -from treequest.vis.snapshot import VisualizationSnapshot -from treequest.vis.renderers.json_yaml import snapshot_to_dict from treequest.vis.renderers.color_utils import ( ROOT_COLOR, ColorMap, apply_status_color, resolve_colormap, ) +from treequest.vis.renderers.json_yaml import snapshot_to_dict +from treequest.vis.snapshot import VisualizationSnapshot def _get_d3_js() -> str: @@ -71,7 +71,7 @@ def render_html( RenderError: If rendering fails """ try: - from jinja2 import Template + from jinja2 import Template # type: ignore[import-not-found] except ImportError: raise DependencyNotFoundError( "jinja2 is not installed. Install it with: pip install treequest[vis]" From 8e0afbc75f1c28d157301d0ba457b6fc5c1d49fd Mon Sep 17 00:00:00 2001 From: takkyu2 Date: Sun, 7 Dec 2025 15:16:27 -0800 Subject: [PATCH 2/3] fix type errors --- src/treequest/algos/ab_mcts_m/_ab_mcts_m_imports.py | 8 ++++---- src/treequest/algos/ab_mcts_m/numpyro_utils.py | 2 +- 2 files changed, 5 insertions(+), 5 deletions(-) diff --git a/src/treequest/algos/ab_mcts_m/_ab_mcts_m_imports.py b/src/treequest/algos/ab_mcts_m/_ab_mcts_m_imports.py index 1ce6443..c8588f5 100644 --- a/src/treequest/algos/ab_mcts_m/_ab_mcts_m_imports.py +++ b/src/treequest/algos/ab_mcts_m/_ab_mcts_m_imports.py @@ -1,14 +1,14 @@ from treequest.imports import try_import with try_import() as _import: - import jax # type: ignore[import-not-found] + import jax from packaging.version import Version # TODO: Remove this hotfix after numpyro fixes incompatibility with jax>=0.7 # https://github.com/pyro-ppl/numpyro/issues/2051 if Version(jax.__version__) >= Version("0.7.0"): - import jax.experimental.pjit as _pjit # type: ignore[import-not-found] - from jax.extend.core.primitives import jit_p # type: ignore[import-not-found] + import jax.experimental.pjit as _pjit + from jax.extend.core.primitives import jit_p _pjit.pjit_p = jit_p # type: ignore import numpy as np @@ -16,7 +16,7 @@ import pandas as pd # type: ignore import pymc as pm # type: ignore from pymc.sampling.jax import sample_numpyro_nuts # type: ignore - from xarray import DataArray # type: ignore[import-not-found] + from xarray import DataArray __all__ = [ "jax", diff --git a/src/treequest/algos/ab_mcts_m/numpyro_utils.py b/src/treequest/algos/ab_mcts_m/numpyro_utils.py index 4c0d6b0..7b05c33 100644 --- a/src/treequest/algos/ab_mcts_m/numpyro_utils.py +++ b/src/treequest/algos/ab_mcts_m/numpyro_utils.py @@ -1,7 +1,7 @@ def initialize_numpyro(num_cpu_devices: int = 4): # For 4 parallel chains import os - import numpyro # type: ignore[import-not-found] + import numpyro # type: ignore[import-untyped] # To avoid file lock error: https://github.com/pymc-devs/pymc/issues/6818 os.environ["PYTENSOR_FLAGS"] = ( From e30ef4149e23debaa63c2ad4a7637f61accfb543 Mon Sep 17 00:00:00 2001 From: takkyu2 Date: Sun, 7 Dec 2025 15:22:29 -0800 Subject: [PATCH 3/3] revert unnecessary fix --- src/treequest/vis/renderers/html.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/treequest/vis/renderers/html.py b/src/treequest/vis/renderers/html.py index c6a9f39..ea6a71d 100644 --- a/src/treequest/vis/renderers/html.py +++ b/src/treequest/vis/renderers/html.py @@ -71,7 +71,7 @@ def render_html( RenderError: If rendering fails """ try: - from jinja2 import Template # type: ignore[import-not-found] + from jinja2 import Template except ImportError: raise DependencyNotFoundError( "jinja2 is not installed. Install it with: pip install treequest[vis]"