-
Notifications
You must be signed in to change notification settings - Fork 139
Use LAPACK functions for cho_solve
, lu_factor
, solve_triangular
#1605
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
base: main
Are you sure you want to change the base?
Changes from all commits
bfdcabf
135b8d9
58a79e8
a536532
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -8,6 +8,7 @@ | |
import scipy.linalg as scipy_linalg | ||
from numpy.exceptions import ComplexWarning | ||
from scipy.linalg import get_lapack_funcs | ||
from scipy.linalg._misc import LinAlgError, LinAlgWarning | ||
|
||
import pytensor | ||
from pytensor import ifelse | ||
|
@@ -384,15 +385,28 @@ def make_node(self, *inputs): | |
return Apply(self, [A, b], [out]) | ||
|
||
def perform(self, node, inputs, output_storage): | ||
C, b = inputs | ||
rval = scipy_linalg.cho_solve( | ||
(C, self.lower), | ||
b, | ||
check_finite=self.check_finite, | ||
overwrite_b=self.overwrite_b, | ||
) | ||
c, b = inputs | ||
|
||
(potrs,) = get_lapack_funcs(("potrs",), (c, b)) | ||
|
||
output_storage[0][0] = rval | ||
if self.check_finite and not (np.isfinite(c).all() and np.isfinite(b).all()): | ||
raise ValueError("array must not contain infs or NaNs") | ||
|
||
if c.ndim != 2 or c.shape[0] != c.shape[1]: | ||
raise ValueError("The factored matrix c is not square.") | ||
if c.shape[1] != b.shape[0]: | ||
raise ValueError(f"incompatible dimensions ({c.shape} and {b.shape})") | ||
Comment on lines
+395
to
+398
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. We don't need to do shape checking in There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. not true, shapes may not be static There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more.
FWIW; we'll deprecate in-place modifications of the shape (also dtype and strides) modifications in numpy 2.4 There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Static here means we don't know the shape until runtime, as in the following graph: import pytensor
import pytensor.tensor as pt
x = pt.vector("x", shape=(None,))
out = pt.exp(x)
fn = pytensor.function([x], out)
fn([1, 2, 3])
fn([1, 2, 3, 4]) # Allowed to call with different input lengths each time |
||
|
||
# Quick return for empty arrays | ||
if b.size == 0: | ||
output_storage[0][0] = np.empty_like(b, dtype=potrs.dtype) | ||
return | ||
|
||
x, info = potrs(c, b, lower=self.lower, overwrite_b=self.overwrite_b) | ||
if info != 0: | ||
raise ValueError(f"illegal value in {-info}th argument of internal potrs") | ||
|
||
Comment on lines
+406
to
+408
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. I'd prefer if we returned a matrix of This might be out of scope for this PR; asking @ricardoV94 for a 2nd opinion |
||
output_storage[0][0] = x | ||
|
||
def L_op(self, *args, **kwargs): | ||
# TODO: Base impl should work, let's try it | ||
|
@@ -696,9 +710,27 @@ def inplace_on_inputs(self, allowed_inplace_inputs: list[int]) -> "Op": | |
def perform(self, node, inputs, outputs): | ||
A = inputs[0] | ||
|
||
LU, p = scipy_linalg.lu_factor( | ||
A, overwrite_a=self.overwrite_a, check_finite=self.check_finite | ||
) | ||
# Quick return for empty arrays | ||
if A.size == 0: | ||
outputs[0][0] = np.empty_like(A) | ||
outputs[1][0] = np.arange(0, dtype=np.int32) | ||
return | ||
|
||
if self.check_finite and not np.isfinite(A).all(): | ||
raise ValueError("array must not contain infs or NaNs") | ||
|
||
(getrf,) = get_lapack_funcs(("getrf",), (A,)) | ||
LU, p, info = getrf(A, overwrite_a=self.overwrite_a) | ||
if info < 0: | ||
raise ValueError( | ||
f"illegal value in {-info}th argument of internal getrf (lu_factor)" | ||
) | ||
if info > 0: | ||
warnings.warn( | ||
f"Diagonal number {info} is exactly zero. Singular matrix.", | ||
LinAlgWarning, | ||
stacklevel=2, | ||
) | ||
Comment on lines
+724
to
+733
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. As above |
||
|
||
outputs[0][0] = LU | ||
outputs[1][0] = p | ||
|
@@ -865,15 +897,51 @@ def __init__(self, *, unit_diagonal=False, **kwargs): | |
|
||
def perform(self, node, inputs, outputs): | ||
A, b = inputs | ||
outputs[0][0] = scipy_linalg.solve_triangular( | ||
A, | ||
b, | ||
lower=self.lower, | ||
trans=0, | ||
unit_diagonal=self.unit_diagonal, | ||
check_finite=self.check_finite, | ||
overwrite_b=self.overwrite_b, | ||
) | ||
|
||
if self.check_finite and not (np.isfinite(A).all() and np.isfinite(b).all()): | ||
raise ValueError("array must not contain infs or NaNs") | ||
|
||
if len(A.shape) != 2 or A.shape[0] != A.shape[1]: | ||
raise ValueError("expected square matrix") | ||
|
||
if A.shape[0] != b.shape[0]: | ||
raise ValueError(f"shapes of a {A.shape} and b {b.shape} are incompatible") | ||
|
||
(trtrs,) = get_lapack_funcs(("trtrs",), (A, b)) | ||
|
||
# Quick return for empty arrays | ||
if b.size == 0: | ||
outputs[0][0] = np.empty_like(b, dtype=trtrs.dtype) | ||
return | ||
|
||
if A.flags["F_CONTIGUOUS"]: | ||
x, info = trtrs( | ||
A, | ||
b, | ||
overwrite_b=self.overwrite_b, | ||
lower=self.lower, | ||
trans=0, | ||
unitdiag=self.unit_diagonal, | ||
) | ||
else: | ||
# transposed system is solved since trtrs expects Fortran ordering | ||
x, info = trtrs( | ||
A.T, | ||
b, | ||
overwrite_b=self.overwrite_b, | ||
lower=not self.lower, | ||
trans=1, | ||
unitdiag=self.unit_diagonal, | ||
) | ||
|
||
if info > 0: | ||
raise LinAlgError( | ||
f"singular matrix: resolution failed at diagonal {info-1}" | ||
) | ||
elif info < 0: | ||
raise ValueError(f"illegal value in {-info}-th argument of internal trtrs") | ||
Comment on lines
+937
to
+942
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. As above |
||
|
||
outputs[0][0] = x | ||
|
||
def L_op(self, inputs, outputs, output_gradients): | ||
res = super().L_op(inputs, outputs, output_gradients) | ||
|
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
from scipy.linalg import LinAlgError, LinAlgWarning
also works :)