pytensor/tensor/slinalg.py (108 changes: 88 additions & 20 deletions)

@@ -8,6 +8,7 @@
 import scipy.linalg as scipy_linalg
 from numpy.exceptions import ComplexWarning
 from scipy.linalg import get_lapack_funcs
+from scipy.linalg._misc import LinAlgError, LinAlgWarning
Comment (Contributor):
from scipy.linalg import LinAlgError, LinAlgWarning also works :)


 import pytensor
 from pytensor import ifelse

@@ -384,15 +385,28 @@ def make_node(self, *inputs):
         return Apply(self, [A, b], [out])

     def perform(self, node, inputs, output_storage):
-        C, b = inputs
-        rval = scipy_linalg.cho_solve(
-            (C, self.lower),
-            b,
-            check_finite=self.check_finite,
-            overwrite_b=self.overwrite_b,
-        )
-
-        output_storage[0][0] = rval
+        c, b = inputs
+
+        (potrs,) = get_lapack_funcs(("potrs",), (c, b))
+
+        if self.check_finite and not (np.isfinite(c).all() and np.isfinite(b).all()):
+            raise ValueError("array must not contain infs or NaNs")
+
+        if c.ndim != 2 or c.shape[0] != c.shape[1]:
+            raise ValueError("The factored matrix c is not square.")
+        if c.shape[1] != b.shape[0]:
+            raise ValueError(f"incompatible dimensions ({c.shape} and {b.shape})")
Comment on lines +395 to +398 (Member):
We don't need to do shape checking in perform, that is handled by make_node

Reply (Member):
not true, shapes may not be static

Reply (Contributor):
> not true, shapes may not be static

FWIW, we'll deprecate in-place modifications of the shape (and also of dtype and strides) in numpy 2.4.
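To make that concrete, a hypothetical illustration (not from this thread) of the kind of in-place modification being referred to, next to the non-mutating alternative:

import numpy as np

a = np.arange(6)
# In-place shape assignment mutates the array object itself; per the
# comment above, this is what numpy plans to deprecate in 2.4.
a.shape = (2, 3)

# The non-mutating alternative returns a (possibly zero-copy) reshaped view.
b = np.arange(6).reshape(2, 3)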

Reply (Member @ricardoV94, Sep 6, 2025):
Not static here means we don't know the shape until runtime, as in the following graph:

import pytensor
import pytensor.tensor as pt

x = pt.vector("x", shape=(None,))
out = pt.exp(x)

fn = pytensor.function([x], out)
fn([1, 2, 3])
fn([1, 2, 3, 4])  # Allowed to call with different input lengths each time


+        # Quick return for empty arrays
+        if b.size == 0:
+            output_storage[0][0] = np.empty_like(b, dtype=potrs.dtype)
+            return
+
+        x, info = potrs(c, b, lower=self.lower, overwrite_b=self.overwrite_b)
+        if info != 0:
+            raise ValueError(f"illegal value in {-info}th argument of internal potrs")

Comment on lines +406 to +408 (Member):
I'd prefer if we returned a matrix of np.nan if info != 0 rather than erroring out. This is what JAX does, and it makes it a lot more ergonomic to work with in iterative algorithms.

This might be out of scope for this PR; asking @ricardoV94 for a 2nd opinion
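For reference, a rough sketch of that alternative (hypothetical, not part of this PR), keeping the potrs call from the hunk above but filling the result with NaNs on failure instead of raising:

x, info = potrs(c, b, lower=self.lower, overwrite_b=self.overwrite_b)
if info != 0:
    # Assumption: mirror JAX by propagating failure as NaNs so that
    # iterative algorithms can detect and handle it downstream.
    x = np.full_like(x, np.nan)
output_storage[0][0] = x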

+
+        output_storage[0][0] = x

     def L_op(self, *args, **kwargs):
         # TODO: Base impl should work, let's try it

@@ -696,9 +710,27 @@ def inplace_on_inputs(self, allowed_inplace_inputs: list[int]) -> "Op":
     def perform(self, node, inputs, outputs):
         A = inputs[0]

-        LU, p = scipy_linalg.lu_factor(
-            A, overwrite_a=self.overwrite_a, check_finite=self.check_finite
-        )
+        # Quick return for empty arrays
+        if A.size == 0:
+            outputs[0][0] = np.empty_like(A)
+            outputs[1][0] = np.arange(0, dtype=np.int32)
+            return
+
+        if self.check_finite and not np.isfinite(A).all():
+            raise ValueError("array must not contain infs or NaNs")
+
+        (getrf,) = get_lapack_funcs(("getrf",), (A,))
+        LU, p, info = getrf(A, overwrite_a=self.overwrite_a)
+        if info < 0:
+            raise ValueError(
+                f"illegal value in {-info}th argument of internal getrf (lu_factor)"
+            )
+        if info > 0:
+            warnings.warn(
+                f"Diagonal number {info} is exactly zero. Singular matrix.",
+                LinAlgWarning,
+                stacklevel=2,
+            )
Comment on lines +724 to +733 (Member):
As above


         outputs[0][0] = LU
         outputs[1][0] = p

@@ -865,15 +897,51 @@ def __init__(self, *, unit_diagonal=False, **kwargs):

     def perform(self, node, inputs, outputs):
         A, b = inputs
-        outputs[0][0] = scipy_linalg.solve_triangular(
-            A,
-            b,
-            lower=self.lower,
-            trans=0,
-            unit_diagonal=self.unit_diagonal,
-            check_finite=self.check_finite,
-            overwrite_b=self.overwrite_b,
-        )
+
+        if self.check_finite and not (np.isfinite(A).all() and np.isfinite(b).all()):
+            raise ValueError("array must not contain infs or NaNs")
+
+        if len(A.shape) != 2 or A.shape[0] != A.shape[1]:
+            raise ValueError("expected square matrix")
+
+        if A.shape[0] != b.shape[0]:
+            raise ValueError(f"shapes of a {A.shape} and b {b.shape} are incompatible")
+
+        (trtrs,) = get_lapack_funcs(("trtrs",), (A, b))
+
+        # Quick return for empty arrays
+        if b.size == 0:
+            outputs[0][0] = np.empty_like(b, dtype=trtrs.dtype)
+            return
+
+        if A.flags["F_CONTIGUOUS"]:
+            x, info = trtrs(
+                A,
+                b,
+                overwrite_b=self.overwrite_b,
+                lower=self.lower,
+                trans=0,
+                unitdiag=self.unit_diagonal,
+            )
+        else:
+            # The transposed system is solved, since trtrs expects Fortran ordering
+            x, info = trtrs(
+                A.T,
+                b,
+                overwrite_b=self.overwrite_b,
+                lower=not self.lower,
+                trans=1,
+                unitdiag=self.unit_diagonal,
+            )
+
+        if info > 0:
+            raise LinAlgError(
+                f"singular matrix: resolution failed at diagonal {info - 1}"
+            )
+        elif info < 0:
+            raise ValueError(f"illegal value in {-info}-th argument of internal trtrs")
Comment on lines +937 to +942 (Member):
As above
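Side note on the F_CONTIGUOUS branch above: for a C-ordered A, A.T is an F-contiguous view (no copy), and solving the transposed system with trans=1 and the lower flag flipped is the same linear system. A small standalone check of that equivalence (hypothetical, illustration only):

import numpy as np
from scipy.linalg import get_lapack_funcs

rng = np.random.default_rng(0)
A = np.tril(rng.random((4, 4))) + np.eye(4)  # lower-triangular, C-ordered
b = rng.random((4, 2))

(trtrs,) = get_lapack_funcs(("trtrs",), (A, b))

# A.T is a zero-copy F-contiguous view; solving its transposed system
# (trans=1, treated as upper-triangular) is equivalent to A @ x = b.
x1, info = trtrs(A.T, b, lower=False, trans=1, unitdiag=False)
x2 = np.linalg.solve(A, b)
assert info == 0 and np.allclose(x1, x2)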


+        outputs[0][0] = x

     def L_op(self, inputs, outputs, output_gradients):
         res = super().L_op(inputs, outputs, output_gradients)
tests/tensor/test_slinalg.py (52 changes: 52 additions & 0 deletions)

@@ -513,6 +513,31 @@ def solve_op(A, b):

         utt.verify_grad(solve_op, [A_val, b_val], 3, rng, eps=eps)

+    def test_solve_triangular_empty(self):
+        rng = np.random.default_rng(utt.fetch_seed())
+        A = pt.tensor("A", shape=(5, 5))
+        b = pt.tensor("b", shape=(5, 0))
+
+        A_val = rng.random((5, 5)).astype(config.floatX)
+        b_empty = np.empty([5, 0], dtype=config.floatX)
+
+        A_func = functools.partial(self.A_func, lower=True, unit_diagonal=True)
+
+        x = solve_triangular(
+            A_func(A),
+            b,
+            lower=True,
+            trans=0,
+            unit_diagonal=True,
+            b_ndim=2,
+        )
+
+        f = function([A, b], x)
+
+        res = f(A_val, b_empty)
+        assert res.size == 0
+        assert res.dtype == config.floatX
+

 class TestCholeskySolve(utt.InferShapeTester):
     def setup_method(self):
@@ -797,6 +822,18 @@ def test_lu_factor():
     )


+def test_lu_factor_empty():
+    A = matrix()
+    f = function([A], lu_factor(A))
+
+    A_empty = np.empty([0, 0], dtype=config.floatX)
+    LU, pt_p_idx = f(A_empty)
+
+    assert LU.size == 0
+    assert LU.dtype == config.floatX
+    assert pt_p_idx.size == 0


 def test_cho_solve():
     rng = np.random.default_rng(utt.fetch_seed())
     A = matrix()

@@ -814,6 +851,21 @@ def test_cho_solve():
     )


+def test_cho_solve_empty():
+    rng = np.random.default_rng(utt.fetch_seed())
+    A = matrix()
+    b = matrix()
+    y = cho_solve((A, True), b)
+    cho_solve_lower_func = function([A, b], y)
+
+    # Only b is empty; A is an ordinary lower-triangular factor
+    A_val = np.tril(np.asarray(rng.random((5, 5)), dtype=config.floatX))
+    b_empty = np.empty([5, 0], dtype=config.floatX)
+
+    res = cho_solve_lower_func(A_val, b_empty)
+    assert res.size == 0
+    assert res.dtype == config.floatX


 def test_expm():
     rng = np.random.default_rng(utt.fetch_seed())
     A = rng.standard_normal((5, 5)).astype(config.floatX)