Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion .pre-commit-config.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,7 @@ ci:

repos:
- repo: https://github.com/astral-sh/ruff-pre-commit
rev: v0.13.3
rev: v0.14.0
hooks:
- id: ruff
args: ["--fix", "--output-format=full"]
Expand Down
7 changes: 3 additions & 4 deletions pymc_bart/bart.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,7 +16,6 @@

import warnings
from multiprocessing import Manager
from typing import Optional

import numpy as np
import numpy.typing as npt
Expand Down Expand Up @@ -130,9 +129,9 @@ def __new__(
alpha: float = 0.95,
beta: float = 2.0,
response: str = "constant",
split_prior: Optional[npt.NDArray] = None,
split_rules: Optional[list[SplitRule]] = None,
separate_trees: Optional[bool] = False,
split_prior: npt.NDArray | None = None,
split_rules: list[SplitRule] | None = None,
separate_trees: bool | None = False,
**kwargs,
):
if response in ["linear", "mix"]:
Expand Down
13 changes: 6 additions & 7 deletions pymc_bart/pgbart.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,6 @@
# See the License for the specific language governing permissions and
# limitations under the License.

from typing import Optional, Union

import numpy as np
import numpy.typing as npt
Expand Down Expand Up @@ -128,7 +127,7 @@ def __init__( # noqa: PLR0912, PLR0915
vars: list[pm.Distribution] | None = None,
num_particles: int = 10,
batch: tuple[float, float] = (0.1, 0.1),
model: Optional[Model] = None,
model: Model | None = None,
initial_point: PointType | None = None,
compile_kwargs: dict | None = None,
**kwargs, # Accept additional kwargs for compound sampling
Expand Down Expand Up @@ -445,7 +444,7 @@ def __init__(self, shape: tuple[int, ...]) -> None:
self.mean = np.zeros(shape) # running mean
self.m_2 = np.zeros(shape) # running second moment

def update(self, new_value: npt.NDArray) -> Union[float, npt.NDArray]:
def update(self, new_value: npt.NDArray) -> float | npt.NDArray:
self.count = self.count + 1
self.mean, self.m_2, std = _update(self.count, self.mean, self.m_2, new_value)
return fast_mean(std)
Expand All @@ -457,7 +456,7 @@ def _update(
mean: npt.NDArray,
m_2: npt.NDArray,
new_value: npt.NDArray,
) -> tuple[npt.NDArray, npt.NDArray, Union[float, npt.NDArray]]:
) -> tuple[npt.NDArray, npt.NDArray, float | npt.NDArray]:
delta = new_value - mean
mean += delta / count
delta2 = new_value - mean
Expand All @@ -477,7 +476,7 @@ def __init__(self, alpha_vec: npt.NDArray) -> None:
"""
self.enu = list(enumerate(np.cumsum(alpha_vec / alpha_vec.sum())))

def rvs(self) -> Union[int, tuple[int, float]]:
def rvs(self) -> int | tuple[int, float]:
rnd: float = np.random.random()
for i, val in self.enu:
if rnd <= val:
Expand Down Expand Up @@ -587,7 +586,7 @@ def draw_leaf_value(
norm: npt.NDArray,
shape: int,
response: str,
) -> tuple[npt.NDArray, Optional[npt.NDArray]]:
) -> tuple[npt.NDArray, npt.NDArray | None]:
"""Draw Gaussian distributed leaf values."""
linear_params = None
mu_mean: npt.NDArray
Expand All @@ -605,7 +604,7 @@ def draw_leaf_value(


@njit
def fast_mean(ari: npt.NDArray) -> Union[float, npt.NDArray]:
def fast_mean(ari: npt.NDArray) -> float | npt.NDArray:
"""Use Numba to speed up the computation of the mean."""
if ari.ndim == 1:
count = ari.shape[0]
Expand Down
19 changes: 9 additions & 10 deletions pymc_bart/tree.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,6 @@

from collections.abc import Generator
from functools import lru_cache
from typing import Optional, Union

import numpy as np
import numpy.typing as npt
Expand All @@ -40,9 +39,9 @@ def __init__(
self,
value: npt.NDArray = np.array([-1.0]),
nvalue: int = 0,
idx_data_points: Optional[npt.NDArray[np.int_]] = None,
idx_data_points: npt.NDArray[np.int_] | None = None,
idx_split_variable: int = -1,
linear_params: Optional[list[npt.NDArray]] = None,
linear_params: list[npt.NDArray] | None = None,
) -> None:
self.value = value
self.nvalue = nvalue
Expand All @@ -55,9 +54,9 @@ def new_leaf_node(
cls,
value: npt.NDArray,
nvalue: int = 0,
idx_data_points: Optional[npt.NDArray[np.int_]] = None,
idx_data_points: npt.NDArray[np.int_] | None = None,
idx_split_variable: int = -1,
linear_params: Optional[list[npt.NDArray]] = None,
linear_params: list[npt.NDArray] | None = None,
) -> "Node":
return cls(
value=value,
Expand Down Expand Up @@ -124,7 +123,7 @@ def __init__(
tree_structure: dict[int, Node],
output: npt.NDArray,
split_rules: list[SplitRule],
idx_leaf_nodes: Optional[list[int]] = None,
idx_leaf_nodes: list[int] | None = None,
) -> None:
self.tree_structure = tree_structure
self.idx_leaf_nodes = idx_leaf_nodes
Expand All @@ -135,7 +134,7 @@ def __init__(
def new_tree(
cls,
leaf_node_value: npt.NDArray,
idx_data_points: Optional[npt.NDArray[np.int_]],
idx_data_points: npt.NDArray[np.int_] | None,
num_observations: int,
shape: int,
split_rules: list[SplitRule],
Expand Down Expand Up @@ -234,7 +233,7 @@ def _predict(self) -> npt.NDArray:
def predict(
self,
x: npt.NDArray,
excluded: Optional[list[int]] = None,
excluded: list[int] | None = None,
shape: int = 1,
) -> npt.NDArray:
"""
Expand All @@ -260,8 +259,8 @@ def predict(
def _traverse_tree(
self,
X: npt.NDArray,
excluded: Optional[list[int]] = None,
shape: Union[int, tuple[int, ...]] = 1,
excluded: list[int] | None = None,
shape: int | tuple[int, ...] = 1,
) -> npt.NDArray:
"""
Traverse the tree starting from the root node given an (un)observed point.
Expand Down
Loading