1212# See the License for the specific language governing permissions and
1313# limitations under the License.
1414
15- from typing import Optional , Union
1615
1716import numpy as np
1817import numpy .typing as npt
@@ -128,7 +127,7 @@ def __init__( # noqa: PLR0912, PLR0915
128127 vars : list [pm .Distribution ] | None = None ,
129128 num_particles : int = 10 ,
130129 batch : tuple [float , float ] = (0.1 , 0.1 ),
131- model : Optional [ Model ] = None ,
130+ model : Model | None = None ,
132131 initial_point : PointType | None = None ,
133132 compile_kwargs : dict | None = None ,
134133 ** kwargs , # Accept additional kwargs for compound sampling
@@ -445,7 +444,7 @@ def __init__(self, shape: tuple[int, ...]) -> None:
445444 self .mean = np .zeros (shape ) # running mean
446445 self .m_2 = np .zeros (shape ) # running second moment
447446
448- def update (self , new_value : npt .NDArray ) -> Union [ float , npt .NDArray ] :
447+ def update (self , new_value : npt .NDArray ) -> float | npt .NDArray :
449448 self .count = self .count + 1
450449 self .mean , self .m_2 , std = _update (self .count , self .mean , self .m_2 , new_value )
451450 return fast_mean (std )
@@ -457,7 +456,7 @@ def _update(
457456 mean : npt .NDArray ,
458457 m_2 : npt .NDArray ,
459458 new_value : npt .NDArray ,
460- ) -> tuple [npt .NDArray , npt .NDArray , Union [ float , npt .NDArray ] ]:
459+ ) -> tuple [npt .NDArray , npt .NDArray , float | npt .NDArray ]:
461460 delta = new_value - mean
462461 mean += delta / count
463462 delta2 = new_value - mean
@@ -477,7 +476,7 @@ def __init__(self, alpha_vec: npt.NDArray) -> None:
477476 """
478477 self .enu = list (enumerate (np .cumsum (alpha_vec / alpha_vec .sum ())))
479478
480- def rvs (self ) -> Union [ int , tuple [int , float ] ]:
479+ def rvs (self ) -> int | tuple [int , float ]:
481480 rnd : float = np .random .random ()
482481 for i , val in self .enu :
483482 if rnd <= val :
@@ -587,7 +586,7 @@ def draw_leaf_value(
587586 norm : npt .NDArray ,
588587 shape : int ,
589588 response : str ,
590- ) -> tuple [npt .NDArray , Optional [ npt .NDArray ] ]:
589+ ) -> tuple [npt .NDArray , npt .NDArray | None ]:
591590 """Draw Gaussian distributed leaf values."""
592591 linear_params = None
593592 mu_mean : npt .NDArray
@@ -605,7 +604,7 @@ def draw_leaf_value(
605604
606605
607606@njit
608- def fast_mean (ari : npt .NDArray ) -> Union [ float , npt .NDArray ] :
607+ def fast_mean (ari : npt .NDArray ) -> float | npt .NDArray :
609608 """Use Numba to speed up the computation of the mean."""
610609 if ari .ndim == 1 :
611610 count = ari .shape [0 ]
0 commit comments