Skip to content

Commit 11b5d54

Browse files
[pre-commit.ci] pre-commit autoupdate (#244)
* [pre-commit.ci] pre-commit autoupdate updates: - [github.com/astral-sh/ruff-pre-commit: v0.13.3 → v0.14.0](astral-sh/ruff-pre-commit@v0.13.3...v0.14.0) * [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci --------- Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
1 parent 99f14fa commit 11b5d54

File tree

4 files changed

+19
-22
lines changed

4 files changed

+19
-22
lines changed

.pre-commit-config.yaml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -12,7 +12,7 @@ ci:
1212

1313
repos:
1414
- repo: https://github.com/astral-sh/ruff-pre-commit
15-
rev: v0.13.3
15+
rev: v0.14.0
1616
hooks:
1717
- id: ruff
1818
args: ["--fix", "--output-format=full"]

pymc_bart/bart.py

Lines changed: 3 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -16,7 +16,6 @@
1616

1717
import warnings
1818
from multiprocessing import Manager
19-
from typing import Optional
2019

2120
import numpy as np
2221
import numpy.typing as npt
@@ -130,9 +129,9 @@ def __new__(
130129
alpha: float = 0.95,
131130
beta: float = 2.0,
132131
response: str = "constant",
133-
split_prior: Optional[npt.NDArray] = None,
134-
split_rules: Optional[list[SplitRule]] = None,
135-
separate_trees: Optional[bool] = False,
132+
split_prior: npt.NDArray | None = None,
133+
split_rules: list[SplitRule] | None = None,
134+
separate_trees: bool | None = False,
136135
**kwargs,
137136
):
138137
if response in ["linear", "mix"]:

pymc_bart/pgbart.py

Lines changed: 6 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -12,7 +12,6 @@
1212
# See the License for the specific language governing permissions and
1313
# limitations under the License.
1414

15-
from typing import Optional, Union
1615

1716
import numpy as np
1817
import numpy.typing as npt
@@ -128,7 +127,7 @@ def __init__( # noqa: PLR0912, PLR0915
128127
vars: list[pm.Distribution] | None = None,
129128
num_particles: int = 10,
130129
batch: tuple[float, float] = (0.1, 0.1),
131-
model: Optional[Model] = None,
130+
model: Model | None = None,
132131
initial_point: PointType | None = None,
133132
compile_kwargs: dict | None = None,
134133
**kwargs, # Accept additional kwargs for compound sampling
@@ -445,7 +444,7 @@ def __init__(self, shape: tuple[int, ...]) -> None:
445444
self.mean = np.zeros(shape) # running mean
446445
self.m_2 = np.zeros(shape) # running second moment
447446

448-
def update(self, new_value: npt.NDArray) -> Union[float, npt.NDArray]:
447+
def update(self, new_value: npt.NDArray) -> float | npt.NDArray:
449448
self.count = self.count + 1
450449
self.mean, self.m_2, std = _update(self.count, self.mean, self.m_2, new_value)
451450
return fast_mean(std)
@@ -457,7 +456,7 @@ def _update(
457456
mean: npt.NDArray,
458457
m_2: npt.NDArray,
459458
new_value: npt.NDArray,
460-
) -> tuple[npt.NDArray, npt.NDArray, Union[float, npt.NDArray]]:
459+
) -> tuple[npt.NDArray, npt.NDArray, float | npt.NDArray]:
461460
delta = new_value - mean
462461
mean += delta / count
463462
delta2 = new_value - mean
@@ -477,7 +476,7 @@ def __init__(self, alpha_vec: npt.NDArray) -> None:
477476
"""
478477
self.enu = list(enumerate(np.cumsum(alpha_vec / alpha_vec.sum())))
479478

480-
def rvs(self) -> Union[int, tuple[int, float]]:
479+
def rvs(self) -> int | tuple[int, float]:
481480
rnd: float = np.random.random()
482481
for i, val in self.enu:
483482
if rnd <= val:
@@ -587,7 +586,7 @@ def draw_leaf_value(
587586
norm: npt.NDArray,
588587
shape: int,
589588
response: str,
590-
) -> tuple[npt.NDArray, Optional[npt.NDArray]]:
589+
) -> tuple[npt.NDArray, npt.NDArray | None]:
591590
"""Draw Gaussian distributed leaf values."""
592591
linear_params = None
593592
mu_mean: npt.NDArray
@@ -605,7 +604,7 @@ def draw_leaf_value(
605604

606605

607606
@njit
608-
def fast_mean(ari: npt.NDArray) -> Union[float, npt.NDArray]:
607+
def fast_mean(ari: npt.NDArray) -> float | npt.NDArray:
609608
"""Use Numba to speed up the computation of the mean."""
610609
if ari.ndim == 1:
611610
count = ari.shape[0]

pymc_bart/tree.py

Lines changed: 9 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -14,7 +14,6 @@
1414

1515
from collections.abc import Generator
1616
from functools import lru_cache
17-
from typing import Optional, Union
1817

1918
import numpy as np
2019
import numpy.typing as npt
@@ -40,9 +39,9 @@ def __init__(
4039
self,
4140
value: npt.NDArray = np.array([-1.0]),
4241
nvalue: int = 0,
43-
idx_data_points: Optional[npt.NDArray[np.int_]] = None,
42+
idx_data_points: npt.NDArray[np.int_] | None = None,
4443
idx_split_variable: int = -1,
45-
linear_params: Optional[list[npt.NDArray]] = None,
44+
linear_params: list[npt.NDArray] | None = None,
4645
) -> None:
4746
self.value = value
4847
self.nvalue = nvalue
@@ -55,9 +54,9 @@ def new_leaf_node(
5554
cls,
5655
value: npt.NDArray,
5756
nvalue: int = 0,
58-
idx_data_points: Optional[npt.NDArray[np.int_]] = None,
57+
idx_data_points: npt.NDArray[np.int_] | None = None,
5958
idx_split_variable: int = -1,
60-
linear_params: Optional[list[npt.NDArray]] = None,
59+
linear_params: list[npt.NDArray] | None = None,
6160
) -> "Node":
6261
return cls(
6362
value=value,
@@ -124,7 +123,7 @@ def __init__(
124123
tree_structure: dict[int, Node],
125124
output: npt.NDArray,
126125
split_rules: list[SplitRule],
127-
idx_leaf_nodes: Optional[list[int]] = None,
126+
idx_leaf_nodes: list[int] | None = None,
128127
) -> None:
129128
self.tree_structure = tree_structure
130129
self.idx_leaf_nodes = idx_leaf_nodes
@@ -135,7 +134,7 @@ def __init__(
135134
def new_tree(
136135
cls,
137136
leaf_node_value: npt.NDArray,
138-
idx_data_points: Optional[npt.NDArray[np.int_]],
137+
idx_data_points: npt.NDArray[np.int_] | None,
139138
num_observations: int,
140139
shape: int,
141140
split_rules: list[SplitRule],
@@ -234,7 +233,7 @@ def _predict(self) -> npt.NDArray:
234233
def predict(
235234
self,
236235
x: npt.NDArray,
237-
excluded: Optional[list[int]] = None,
236+
excluded: list[int] | None = None,
238237
shape: int = 1,
239238
) -> npt.NDArray:
240239
"""
@@ -260,8 +259,8 @@ def predict(
260259
def _traverse_tree(
261260
self,
262261
X: npt.NDArray,
263-
excluded: Optional[list[int]] = None,
264-
shape: Union[int, tuple[int, ...]] = 1,
262+
excluded: list[int] | None = None,
263+
shape: int | tuple[int, ...] = 1,
265264
) -> npt.NDArray:
266265
"""
267266
Traverse the tree starting from the root node given an (un)observed point.

0 commit comments

Comments
 (0)