Skip to content

Commit b84ba1c

Browse files
authored
Update MyPy 14 (#210)
* move mypy config * some fixes * some fixes * some fixes * some fixes * some fixes * some fixes * remove reference np.float64 * remove unnecessary casting * fix type * fix import
1 parent 8b536b9 commit b84ba1c

File tree

7 files changed

+123
-100
lines changed

7 files changed

+123
-100
lines changed

.pre-commit-config.yaml

+2-2
Original file line number | Diff line number | Diff line change
@@ -12,14 +12,14 @@ ci:
1212

1313
repos:
1414
- repo: https://github.com/astral-sh/ruff-pre-commit
15-
rev: v0.8.3
15+
rev: v0.8.4
1616
hooks:
1717
- id: ruff
1818
args: ["--fix", "--output-format=full"]
1919
- id: ruff-format
2020
args: ["--line-length=100"]
2121
- repo: https://github.com/pre-commit/mirrors-mypy
22-
rev: v1.13.0
22+
rev: v1.14.0
2323
hooks:
2424
- id: mypy
2525
args: [--ignore-missing-imports]

mypy.ini

-15
This file was deleted.

pymc_bart/bart.py

+2-4
Original file line number | Diff line number | Diff line change
@@ -132,7 +132,7 @@ def __new__(
132132
alpha: float = 0.95,
133133
beta: float = 2.0,
134134
response: str = "constant",
135-
split_prior: Optional[npt.NDArray[np.float64]] = None,
135+
split_prior: Optional[npt.NDArray] = None,
136136
split_rules: Optional[list[SplitRule]] = None,
137137
separate_trees: Optional[bool] = False,
138138
**kwargs,
@@ -203,9 +203,7 @@ def get_moment(cls, rv, size, *rv_inputs):
203203
return mean
204204

205205

206-
def preprocess_xy(
207-
X: TensorLike, Y: TensorLike
208-
) -> tuple[npt.NDArray[np.float64], npt.NDArray[np.float64]]:
206+
def preprocess_xy(X: TensorLike, Y: TensorLike) -> tuple[npt.NDArray, npt.NDArray]:
209207
if isinstance(Y, (Series, DataFrame)):
210208
Y = Y.to_numpy()
211209
if isinstance(X, (Series, DataFrame)):

pymc_bart/pgbart.py

+42-31
Original file line number | Diff line number | Diff line change
@@ -16,6 +16,8 @@
1616

1717
import numpy as np
1818
import numpy.typing as npt
19+
import pymc as pm
20+
import pytensor.tensor as pt
1921
from numba import njit
2022
from pymc.initial_point import PointType
2123
from pymc.model import Model, modelcontext
@@ -120,15 +122,15 @@ class PGBART(ArrayStepShared):
120122
"tune": (bool, []),
121123
}
122124

123-
def __init__( # noqa: PLR0915
125+
def __init__( # noqa: PLR0912, PLR0915
124126
self,
125-
vars=None, # pylint: disable=redefined-builtin
127+
vars: list[pm.Distribution] | None = None,
126128
num_particles: int = 10,
127129
batch: tuple[float, float] = (0.1, 0.1),
128130
model: Optional[Model] = None,
129131
initial_point: PointType | None = None,
130-
compile_kwargs: dict | None = None, # pylint: disable=unused-argument
131-
):
132+
compile_kwargs: dict | None = None,
133+
) -> None:
132134
model = modelcontext(model)
133135
if initial_point is None:
134136
initial_point = model.initial_point()
@@ -137,6 +139,10 @@ def __init__( # noqa: PLR0915
137139
else:
138140
vars = [model.rvs_to_values.get(var, var) for var in vars]
139141
vars = inputvars(vars)
142+
143+
if vars is None:
144+
raise ValueError("Unable to find variables to sample")
145+
140146
value_bart = vars[0]
141147
self.bart = model.values_to_rvs[value_bart].owner.op
142148

@@ -325,7 +331,7 @@ def normalize(self, particles: list[ParticleTree]) -> float:
325331
return wei / wei.sum()
326332

327333
def resample(
328-
self, particles: list[ParticleTree], normalized_weights: npt.NDArray[np.float64]
334+
self, particles: list[ParticleTree], normalized_weights: npt.NDArray
329335
) -> list[ParticleTree]:
330336
"""
331337
Use systematic resample for all but the first particle
@@ -347,7 +353,7 @@ def resample(
347353
return particles
348354

349355
def get_particle_tree(
350-
self, particles: list[ParticleTree], normalized_weights: npt.NDArray[np.float64]
356+
self, particles: list[ParticleTree], normalized_weights: npt.NDArray
351357
) -> tuple[ParticleTree, Tree]:
352358
"""
353359
Sample a new particle and associated tree
@@ -359,7 +365,7 @@ def get_particle_tree(
359365

360366
return new_particle, new_particle.tree
361367

362-
def systematic(self, normalized_weights: npt.NDArray[np.float64]) -> npt.NDArray[np.int_]:
368+
def systematic(self, normalized_weights: npt.NDArray) -> npt.NDArray[np.int_]:
363369
"""
364370
Systematic resampling.
365371
@@ -395,7 +401,7 @@ def update_weight(self, particle: ParticleTree, odim: int) -> None:
395401
particle.log_weight = new_likelihood
396402

397403
@staticmethod
398-
def competence(var, has_grad):
404+
def competence(var: pm.Distribution, has_grad: bool) -> Competence:
399405
"""PGBART is only suitable for BART distributions."""
400406
dist = getattr(var.owner, "op", None)
401407
if isinstance(dist, BARTRV):
@@ -406,12 +412,12 @@ def competence(var, has_grad):
406412
class RunningSd:
407413
"""Welford's online algorithm for computing the variance/standard deviation"""
408414

409-
def __init__(self, shape: tuple) -> None:
415+
def __init__(self, shape: tuple[int, ...]) -> None:
410416
self.count = 0 # number of data points
411417
self.mean = np.zeros(shape) # running mean
412418
self.m_2 = np.zeros(shape) # running second moment
413419

414-
def update(self, new_value: npt.NDArray[np.float64]) -> Union[float, npt.NDArray[np.float64]]:
420+
def update(self, new_value: npt.NDArray) -> Union[float, npt.NDArray]:
415421
self.count = self.count + 1
416422
self.mean, self.m_2, std = _update(self.count, self.mean, self.m_2, new_value)
417423
return fast_mean(std)
@@ -420,10 +426,10 @@ def update(self, new_value: npt.NDArray[np.float64]) -> Union[float, npt.NDArray
420426
@njit
421427
def _update(
422428
count: int,
423-
mean: npt.NDArray[np.float64],
424-
m_2: npt.NDArray[np.float64],
425-
new_value: npt.NDArray[np.float64],
426-
) -> tuple[npt.NDArray[np.float64], npt.NDArray[np.float64], Union[float, npt.NDArray[np.float64]]]:
429+
mean: npt.NDArray,
430+
m_2: npt.NDArray,
431+
new_value: npt.NDArray,
432+
) -> tuple[npt.NDArray, npt.NDArray, Union[float, npt.NDArray]]:
427433
delta = new_value - mean
428434
mean += delta / count
429435
delta2 = new_value - mean
@@ -434,7 +440,7 @@ def _update(
434440

435441

436442
class SampleSplittingVariable:
437-
def __init__(self, alpha_vec: npt.NDArray[np.float64]) -> None:
443+
def __init__(self, alpha_vec: npt.NDArray) -> None:
438444
"""
439445
Sample splitting variables proportional to `alpha_vec`.
440446
@@ -547,16 +553,16 @@ def filter_missing_values(available_splitting_values, idx_data_points, missing_d
547553

548554

549555
def draw_leaf_value(
550-
y_mu_pred: npt.NDArray[np.float64],
551-
x_mu: npt.NDArray[np.float64],
556+
y_mu_pred: npt.NDArray,
557+
x_mu: npt.NDArray,
552558
m: int,
553-
norm: npt.NDArray[np.float64],
559+
norm: npt.NDArray,
554560
shape: int,
555561
response: str,
556-
) -> tuple[npt.NDArray[np.float64], Optional[npt.NDArray[np.float64]]]:
562+
) -> tuple[npt.NDArray, Optional[npt.NDArray]]:
557563
"""Draw Gaussian distributed leaf values."""
558564
linear_params = None
559-
mu_mean = np.empty(shape)
565+
mu_mean: npt.NDArray
560566
if y_mu_pred.size == 0:
561567
return np.zeros(shape), linear_params
562568

@@ -571,7 +577,7 @@ def draw_leaf_value(
571577

572578

573579
@njit
574-
def fast_mean(ari: npt.NDArray[np.float64]) -> Union[float, npt.NDArray[np.float64]]:
580+
def fast_mean(ari: npt.NDArray) -> Union[float, npt.NDArray]:
575581
"""Use Numba to speed up the computation of the mean."""
576582
if ari.ndim == 1:
577583
count = ari.shape[0]
@@ -590,11 +596,11 @@ def fast_mean(ari: npt.NDArray[np.float64]) -> Union[float, npt.NDArray[np.float
590596

591597
@njit
592598
def fast_linear_fit(
593-
x: npt.NDArray[np.float64],
594-
y: npt.NDArray[np.float64],
599+
x: npt.NDArray,
600+
y: npt.NDArray,
595601
m: int,
596-
norm: npt.NDArray[np.float64],
597-
) -> tuple[npt.NDArray[np.float64], list[npt.NDArray[np.float64]]]:
602+
norm: npt.NDArray,
603+
) -> tuple[npt.NDArray, list[npt.NDArray]]:
598604
n = len(x)
599605
y = y / m + np.expand_dims(norm, axis=1)
600606

@@ -678,17 +684,17 @@ def update(self):
678684

679685
@njit
680686
def inverse_cdf(
681-
single_uniform: npt.NDArray[np.float64], normalized_weights: npt.NDArray[np.float64]
687+
single_uniform: npt.NDArray, normalized_weights: npt.NDArray
682688
) -> npt.NDArray[np.int_]:
683689
"""
684690
Inverse CDF algorithm for a finite distribution.
685691
686692
Parameters
687693
----------
688-
single_uniform: npt.NDArray[np.float64]
694+
single_uniform: npt.NDArray
689695
Ordered points in [0,1]
690696
691-
normalized_weights: npt.NDArray[np.float64])
697+
normalized_weights: npt.NDArray)
692698
Normalized weights
693699
694700
Returns
@@ -711,7 +717,7 @@ def inverse_cdf(
711717

712718

713719
@njit
714-
def jitter_duplicated(array: npt.NDArray[np.float64], std: float) -> npt.NDArray[np.float64]:
720+
def jitter_duplicated(array: npt.NDArray, std: float) -> npt.NDArray:
715721
"""
716722
Jitter duplicated values.
717723
"""
@@ -727,12 +733,17 @@ def jitter_duplicated(array: npt.NDArray[np.float64], std: float) -> npt.NDArray
727733

728734

729735
@njit
730-
def are_whole_number(array: npt.NDArray[np.float64]) -> np.bool_:
736+
def are_whole_number(array: npt.NDArray) -> np.bool_:
731737
"""Check if all values in array are whole numbers"""
732738
return np.all(np.mod(array[~np.isnan(array)], 1) == 0)
733739

734740

735-
def logp(point, out_vars, vars, shared): # pylint: disable=redefined-builtin
741+
def logp(
742+
point,
743+
out_vars: list[pm.Distribution],
744+
vars: list[pm.Distribution],
745+
shared: list[pt.TensorVariable],
746+
):
736747
"""Compile PyTensor function of the model and the input and output variables.
737748
738749
Parameters

0 commit comments

Comments
 (0)