diff --git a/.pre-commit-config.yaml b/.pre-commit-config.yaml index 48316f4..5b31494 100644 --- a/.pre-commit-config.yaml +++ b/.pre-commit-config.yaml @@ -12,7 +12,7 @@ ci: repos: - repo: https://github.com/astral-sh/ruff-pre-commit - rev: v0.13.3 + rev: v0.14.0 hooks: - id: ruff args: ["--fix", "--output-format=full"] diff --git a/pymc_bart/bart.py b/pymc_bart/bart.py index 233d33e..1c99bce 100644 --- a/pymc_bart/bart.py +++ b/pymc_bart/bart.py @@ -16,7 +16,6 @@ import warnings from multiprocessing import Manager -from typing import Optional import numpy as np import numpy.typing as npt @@ -130,9 +129,9 @@ def __new__( alpha: float = 0.95, beta: float = 2.0, response: str = "constant", - split_prior: Optional[npt.NDArray] = None, - split_rules: Optional[list[SplitRule]] = None, - separate_trees: Optional[bool] = False, + split_prior: npt.NDArray | None = None, + split_rules: list[SplitRule] | None = None, + separate_trees: bool | None = False, **kwargs, ): if response in ["linear", "mix"]: diff --git a/pymc_bart/pgbart.py b/pymc_bart/pgbart.py index 87bd36a..72764b2 100644 --- a/pymc_bart/pgbart.py +++ b/pymc_bart/pgbart.py @@ -12,7 +12,6 @@ # See the License for the specific language governing permissions and # limitations under the License. -from typing import Optional, Union import numpy as np import numpy.typing as npt @@ -128,7 +127,7 @@ def __init__( # noqa: PLR0912, PLR0915 vars: list[pm.Distribution] | None = None, num_particles: int = 10, batch: tuple[float, float] = (0.1, 0.1), - model: Optional[Model] = None, + model: Model | None = None, initial_point: PointType | None = None, compile_kwargs: dict | None = None, **kwargs, # Accept additional kwargs for compound sampling @@ -445,7 +444,7 @@ def __init__(self, shape: tuple[int, ...]) -> None: self.mean = np.zeros(shape) # running mean self.m_2 = np.zeros(shape) # running second moment - def update(self, new_value: npt.NDArray) -> Union[float, npt.NDArray]: + def update(self, new_value: npt.NDArray) -> float | npt.NDArray: self.count = self.count + 1 self.mean, self.m_2, std = _update(self.count, self.mean, self.m_2, new_value) return fast_mean(std) @@ -457,7 +456,7 @@ def _update( mean: npt.NDArray, m_2: npt.NDArray, new_value: npt.NDArray, -) -> tuple[npt.NDArray, npt.NDArray, Union[float, npt.NDArray]]: +) -> tuple[npt.NDArray, npt.NDArray, float | npt.NDArray]: delta = new_value - mean mean += delta / count delta2 = new_value - mean @@ -477,7 +476,7 @@ def __init__(self, alpha_vec: npt.NDArray) -> None: """ self.enu = list(enumerate(np.cumsum(alpha_vec / alpha_vec.sum()))) - def rvs(self) -> Union[int, tuple[int, float]]: + def rvs(self) -> int | tuple[int, float]: rnd: float = np.random.random() for i, val in self.enu: if rnd <= val: @@ -587,7 +586,7 @@ def draw_leaf_value( norm: npt.NDArray, shape: int, response: str, -) -> tuple[npt.NDArray, Optional[npt.NDArray]]: +) -> tuple[npt.NDArray, npt.NDArray | None]: """Draw Gaussian distributed leaf values.""" linear_params = None mu_mean: npt.NDArray @@ -605,7 +604,7 @@ def draw_leaf_value( @njit -def fast_mean(ari: npt.NDArray) -> Union[float, npt.NDArray]: +def fast_mean(ari: npt.NDArray) -> float | npt.NDArray: """Use Numba to speed up the computation of the mean.""" if ari.ndim == 1: count = ari.shape[0] diff --git a/pymc_bart/tree.py b/pymc_bart/tree.py index 61e5050..250e16b 100644 --- a/pymc_bart/tree.py +++ b/pymc_bart/tree.py @@ -14,7 +14,6 @@ from collections.abc import Generator from functools import lru_cache -from typing import Optional, Union import numpy as np import numpy.typing as npt @@ -40,9 +39,9 @@ def __init__( self, value: npt.NDArray = np.array([-1.0]), nvalue: int = 0, - idx_data_points: Optional[npt.NDArray[np.int_]] = None, + idx_data_points: npt.NDArray[np.int_] | None = None, idx_split_variable: int = -1, - linear_params: Optional[list[npt.NDArray]] = None, + linear_params: list[npt.NDArray] | None = None, ) -> None: self.value = value self.nvalue = nvalue @@ -55,9 +54,9 @@ def new_leaf_node( cls, value: npt.NDArray, nvalue: int = 0, - idx_data_points: Optional[npt.NDArray[np.int_]] = None, + idx_data_points: npt.NDArray[np.int_] | None = None, idx_split_variable: int = -1, - linear_params: Optional[list[npt.NDArray]] = None, + linear_params: list[npt.NDArray] | None = None, ) -> "Node": return cls( value=value, @@ -124,7 +123,7 @@ def __init__( tree_structure: dict[int, Node], output: npt.NDArray, split_rules: list[SplitRule], - idx_leaf_nodes: Optional[list[int]] = None, + idx_leaf_nodes: list[int] | None = None, ) -> None: self.tree_structure = tree_structure self.idx_leaf_nodes = idx_leaf_nodes @@ -135,7 +134,7 @@ def __init__( def new_tree( cls, leaf_node_value: npt.NDArray, - idx_data_points: Optional[npt.NDArray[np.int_]], + idx_data_points: npt.NDArray[np.int_] | None, num_observations: int, shape: int, split_rules: list[SplitRule], @@ -234,7 +233,7 @@ def _predict(self) -> npt.NDArray: def predict( self, x: npt.NDArray, - excluded: Optional[list[int]] = None, + excluded: list[int] | None = None, shape: int = 1, ) -> npt.NDArray: """ @@ -260,8 +259,8 @@ def predict( def _traverse_tree( self, X: npt.NDArray, - excluded: Optional[list[int]] = None, - shape: Union[int, tuple[int, ...]] = 1, + excluded: list[int] | None = None, + shape: int | tuple[int, ...] = 1, ) -> npt.NDArray: """ Traverse the tree starting from the root node given an (un)observed point.