diff --git a/pytensor/tensor/slinalg.py b/pytensor/tensor/slinalg.py
index 68d056fdc0..35a7c62e22 100644
--- a/pytensor/tensor/slinalg.py
+++ b/pytensor/tensor/slinalg.py
@@ -8,6 +8,7 @@
 import scipy.linalg as scipy_linalg
 from numpy.exceptions import ComplexWarning
 from scipy.linalg import get_lapack_funcs
+from scipy.linalg import LinAlgError, LinAlgWarning
 
 import pytensor
 from pytensor import ifelse
@@ -384,15 +385,28 @@ def make_node(self, *inputs):
         return Apply(self, [A, b], [out])
 
     def perform(self, node, inputs, output_storage):
-        C, b = inputs
-        rval = scipy_linalg.cho_solve(
-            (C, self.lower),
-            b,
-            check_finite=self.check_finite,
-            overwrite_b=self.overwrite_b,
-        )
+        c, b = inputs
+
+        (potrs,) = get_lapack_funcs(("potrs",), (c, b))
 
-        output_storage[0][0] = rval
+        if self.check_finite and not (np.isfinite(c).all() and np.isfinite(b).all()):
+            raise ValueError("array must not contain infs or NaNs")
+
+        if c.ndim != 2 or c.shape[0] != c.shape[1]:
+            raise ValueError("The factored matrix c is not square.")
+        if c.shape[1] != b.shape[0]:
+            raise ValueError(f"incompatible dimensions ({c.shape} and {b.shape})")
+
+        # Quick return for empty arrays
+        if b.size == 0:
+            output_storage[0][0] = np.empty_like(b, dtype=potrs.dtype)
+            return
+
+        x, info = potrs(c, b, lower=self.lower, overwrite_b=self.overwrite_b)
+        if info != 0:
+            raise ValueError(f"illegal value in {-info}th argument of internal potrs")
+
+        output_storage[0][0] = x
 
     def L_op(self, *args, **kwargs):
         # TODO: Base impl should work, let's try it
@@ -696,9 +710,27 @@ def inplace_on_inputs(self, allowed_inplace_inputs: list[int]) -> "Op":
 
     def perform(self, node, inputs, outputs):
         A = inputs[0]
-        LU, p = scipy_linalg.lu_factor(
-            A, overwrite_a=self.overwrite_a, check_finite=self.check_finite
-        )
+        # Quick return for empty arrays
+        if A.size == 0:
+            outputs[0][0] = np.empty_like(A)
+            outputs[1][0] = np.arange(0, dtype=np.int32)
+            return
+
+        if self.check_finite and not np.isfinite(A).all():
+            raise ValueError("array must not contain infs or NaNs")
+
+        (getrf,) = get_lapack_funcs(("getrf",), (A,))
+        LU, p, info = getrf(A, overwrite_a=self.overwrite_a)
+        if info < 0:
+            raise ValueError(
+                f"illegal value in {-info}th argument of internal getrf (lu_factor)"
+            )
+        if info > 0:
+            warnings.warn(
+                f"Diagonal number {info} is exactly zero. Singular matrix.",
+                LinAlgWarning,
+                stacklevel=2,
+            )
 
         outputs[0][0] = LU
         outputs[1][0] = p
@@ -865,15 +897,51 @@ def __init__(self, *, unit_diagonal=False, **kwargs):
 
     def perform(self, node, inputs, outputs):
         A, b = inputs
-        outputs[0][0] = scipy_linalg.solve_triangular(
-            A,
-            b,
-            lower=self.lower,
-            trans=0,
-            unit_diagonal=self.unit_diagonal,
-            check_finite=self.check_finite,
-            overwrite_b=self.overwrite_b,
-        )
+
+        if self.check_finite and not (np.isfinite(A).all() and np.isfinite(b).all()):
+            raise ValueError("array must not contain infs or NaNs")
+
+        if len(A.shape) != 2 or A.shape[0] != A.shape[1]:
+            raise ValueError("expected square matrix")
+
+        if A.shape[0] != b.shape[0]:
+            raise ValueError(f"shapes of a {A.shape} and b {b.shape} are incompatible")
+
+        (trtrs,) = get_lapack_funcs(("trtrs",), (A, b))
+
+        # Quick return for empty arrays
+        if b.size == 0:
+            outputs[0][0] = np.empty_like(b, dtype=trtrs.dtype)
+            return
+
+        if A.flags["F_CONTIGUOUS"]:
+            x, info = trtrs(
+                A,
+                b,
+                overwrite_b=self.overwrite_b,
+                lower=self.lower,
+                trans=0,
+                unitdiag=self.unit_diagonal,
+            )
+        else:
+            # transposed system is solved since trtrs expects Fortran ordering
+            x, info = trtrs(
+                A.T,
+                b,
+                overwrite_b=self.overwrite_b,
+                lower=not self.lower,
+                trans=1,
+                unitdiag=self.unit_diagonal,
+            )
+
+        if info > 0:
+            raise LinAlgError(
+                f"singular matrix: resolution failed at diagonal {info - 1}"
+            )
+        elif info < 0:
+            raise ValueError(f"illegal value in {-info}-th argument of internal trtrs")
+
+        outputs[0][0] = x
 
     def L_op(self, inputs, outputs, output_gradients):
         res = super().L_op(inputs, outputs, output_gradients)
diff --git a/tests/tensor/test_slinalg.py b/tests/tensor/test_slinalg.py
index a82307a612..f60c10f764 100644
--- a/tests/tensor/test_slinalg.py
+++ b/tests/tensor/test_slinalg.py
@@ -513,6 +513,31 @@ def solve_op(A, b):
 
         utt.verify_grad(solve_op, [A_val, b_val], 3, rng, eps=eps)
 
+    def test_solve_triangular_empty(self):
+        rng = np.random.default_rng(utt.fetch_seed())
+        A = pt.tensor("A", shape=(5, 5))
+        b = pt.tensor("b", shape=(5, 0))
+
+        A_val = rng.random((5, 5)).astype(config.floatX)
+        b_empty = np.empty([5, 0], dtype=config.floatX)
+
+        A_func = functools.partial(self.A_func, lower=True, unit_diagonal=True)
+
+        x = solve_triangular(
+            A_func(A),
+            b,
+            lower=True,
+            trans=0,
+            unit_diagonal=True,
+            b_ndim=2,
+        )
+
+        f = function([A, b], x)
+
+        res = f(A_val, b_empty)
+        assert res.size == 0
+        assert res.dtype == config.floatX
+
 
 class TestCholeskySolve(utt.InferShapeTester):
     def setup_method(self):
@@ -797,6 +822,18 @@ def test_lu_factor():
     )
 
 
+def test_lu_factor_empty():
+    A = matrix()
+    f = function([A], lu_factor(A))
+
+    A_empty = np.empty([0, 0], dtype=config.floatX)
+    LU, p_idx = f(A_empty)
+
+    assert LU.size == 0
+    assert LU.dtype == config.floatX
+    assert p_idx.size == 0
+
+
 def test_cho_solve():
     rng = np.random.default_rng(utt.fetch_seed())
     A = matrix()
@@ -814,6 +851,21 @@ def test_cho_solve():
     )
 
 
+def test_cho_solve_empty():
+    rng = np.random.default_rng(utt.fetch_seed())
+    A = matrix()
+    b = matrix()
+    y = cho_solve((A, True), b)
+    cho_solve_lower_func = function([A, b], y)
+
+    A_val = np.tril(np.asarray(rng.random((5, 5)), dtype=config.floatX))
+    b_empty = np.empty([5, 0], dtype=config.floatX)
+
+    res = cho_solve_lower_func(A_val, b_empty)
+    assert res.size == 0
+    assert res.dtype == config.floatX
+
+
 def test_expm():
     rng = np.random.default_rng(utt.fetch_seed())
     A = rng.standard_normal((5, 5)).astype(config.floatX)
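
Reviewer note, not part of the patch: the three rewritten perform methods share one
calling pattern: an optional finiteness check, a quick return for empty inputs, a
direct call to the routine returned by get_lapack_funcs, and translation of the
LAPACK info code into an exception. Below is a minimal standalone sketch of that
pattern around potrs; the toy matrix and right-hand side are illustrative only.

import numpy as np
from scipy.linalg import cholesky, get_lapack_funcs

a = np.array([[4.0, 2.0], [2.0, 3.0]])
c = cholesky(a, lower=True)  # Cholesky factor, the input cho_solve expects
b = np.array([1.0, 2.0])

# get_lapack_funcs resolves the dtype-matched routine (dpotrs for float64)
(potrs,) = get_lapack_funcs(("potrs",), (c, b))

if not (np.isfinite(c).all() and np.isfinite(b).all()):
    raise ValueError("array must not contain infs or NaNs")

if b.size == 0:
    # Quick return: skip the LAPACK call entirely for empty right-hand sides
    x = np.empty_like(b, dtype=potrs.dtype)
else:
    x, info = potrs(c, b, lower=True)
    if info != 0:
        # Negative info means argument number -info had an illegal value
        raise ValueError(f"illegal value in {-info}th argument of internal potrs")

np.testing.assert_allclose(a @ x, b)  # x solves a @ x = b via the factor c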
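
A second note on the solve_triangular change: trtrs expects Fortran-ordered input,
so for a C-contiguous A the patch hands it A.T (which is F-contiguous, hence no
copy) and solves the transposed system with trans=1 and a flipped lower flag. A
small sketch with illustrative inputs, checking that both branches agree:

import numpy as np
from scipy.linalg import get_lapack_funcs

rng = np.random.default_rng(0)
A = np.tril(rng.random((4, 4))) + 4 * np.eye(4)  # lower triangular, C-contiguous
b = rng.random((4, 1))

(trtrs,) = get_lapack_funcs(("trtrs",), (A, b))

# Direct branch: solve A @ x = b on a Fortran-ordered copy of A
x_f, info_f = trtrs(np.asfortranarray(A), b, lower=True, trans=0)

# Transposed branch: A.T is F-contiguous, and trans=1 with the flipped
# lower flag solves (A.T).T @ x = A @ x = b, the same system
x_t, info_t = trtrs(A.T, b, lower=False, trans=1)

assert info_f == 0 and info_t == 0
np.testing.assert_allclose(x_f, x_t)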