Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
8 changes: 2 additions & 6 deletions pymatsolver/direct/mumps.py
Original file line number Diff line number Diff line change
Expand Up @@ -28,22 +28,18 @@ class Mumps(Base):
The relative tolerance to check against for accuracy.
check_atol : float, optional
The absolute tolerance to check against for accuracy.
accuracy_tol : float, optional
Relative accuracy tolerance.
.. deprecated:: 0.3.0
`accuracy_tol` will be removed in pymatsolver 0.4.0. Use `check_rtol` and `check_atol` instead.
**kwargs
Extra keyword arguments. If there are any left here a warning will be raised.
"""
_transposed = False

def __init__(self, A, ordering=None, is_symmetric=None, is_positive_definite=False, check_accuracy=False, check_rtol=1e-6, check_atol=0, accuracy_tol=None, **kwargs):
def __init__(self, A, ordering=None, is_symmetric=None, is_positive_definite=False, check_accuracy=False, check_rtol=1e-6, check_atol=0, **kwargs):
if not _available:
raise ImportError(
"The Mumps solver requires the python-mumps package to be installed."
)
is_hermitian = kwargs.pop('is_hermitian', False)
super().__init__(A, is_symmetric=is_symmetric, is_positive_definite=is_positive_definite, is_hermitian=is_hermitian, check_accuracy=check_accuracy, check_rtol=check_rtol, check_atol=check_atol, accuracy_tol=accuracy_tol, **kwargs)
super().__init__(A, is_symmetric=is_symmetric, is_positive_definite=is_positive_definite, is_hermitian=is_hermitian, check_accuracy=check_accuracy, check_rtol=check_rtol, check_atol=check_atol, **kwargs)
if ordering is None:
ordering = "metis"
self.ordering = ordering
Expand Down
8 changes: 2 additions & 6 deletions pymatsolver/direct/pardiso.py
Original file line number Diff line number Diff line change
Expand Up @@ -32,20 +32,16 @@ class Pardiso(Base):
The relative tolerance to check against for accuracy.
check_atol : float, optional
The absolute tolerance to check against for accuracy.
accuracy_tol : float, optional
Relative accuracy tolerance.
.. deprecated:: 0.3.0
`accuracy_tol` will be removed in pymatsolver 0.4.0. Use `check_rtol` and `check_atol` instead.
**kwargs
Extra keyword arguments. If there are any left here a warning will be raised.
"""

_transposed = False

def __init__(self, A, n_threads=None, is_symmetric=None, is_positive_definite=False, is_hermitian=None, check_accuracy=False, check_rtol=1e-6, check_atol=0, accuracy_tol=None, **kwargs):
def __init__(self, A, n_threads=None, is_symmetric=None, is_positive_definite=False, is_hermitian=None, check_accuracy=False, check_rtol=1e-6, check_atol=0, **kwargs):
if not _available:
raise ImportError("Pardiso solver requires the pydiso package to be installed.")
super().__init__(A, is_symmetric=is_symmetric, is_positive_definite=is_positive_definite, is_hermitian=is_hermitian, check_accuracy=check_accuracy, check_rtol=check_rtol, check_atol=check_atol, accuracy_tol=accuracy_tol, **kwargs)
super().__init__(A, is_symmetric=is_symmetric, is_positive_definite=is_positive_definite, is_hermitian=is_hermitian, check_accuracy=check_accuracy, check_rtol=check_rtol, check_atol=check_atol, **kwargs)
self.solver = MKLPardisoSolver(
self.A,
matrix_type=self._matrixType(),
Expand Down
18 changes: 4 additions & 14 deletions pymatsolver/iterative.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,9 +21,6 @@ class BiCGJacobi(Base):
----------
A : matrix
The matrix to solve, must have a ``diagonal()`` method.
symmetric: boolean, optional
.. deprecated:: 0.3.0
`symmetric` is deprecated. It is unused, and will be removed in pymatsolver 0.4.0.
maxiter : int, optional
The maximum number of BiCG iterations to perform.
rtol : float, optional
Expand All @@ -36,21 +33,14 @@ class BiCGJacobi(Base):
The relative tolerance to check against for accuracy.
check_atol : float, optional
The absolute tolerance to check against for accuracy.
accuracy_tol : float, optional
Relative accuracy tolerance.
.. deprecated:: 0.3.0
`accuracy_tol` will be removed in pymatsolver 0.4.0. Use `check_rtol` and `check_atol` instead.
**kwargs
Extra keyword arguments passed to the base class.
"""

def __init__(self, A, symmetric=None, maxiter=1000, rtol=1E-6, atol=0.0, check_accuracy=False, check_rtol=1e-6, check_atol=0, accuracy_tol=None, **kwargs):
if symmetric is not None:
warnings.warn(
"The symmetric keyword argument is unused and is deprecated. It will be removed in pymatsolver 0.4.0.",
FutureWarning, stacklevel=2
)
super().__init__(A, check_accuracy=check_accuracy, check_rtol=check_rtol, check_atol=check_atol, accuracy_tol=accuracy_tol, **kwargs)
def __init__(self, A, maxiter=1000, rtol=1E-6, atol=0.0, check_accuracy=False, check_rtol=1e-6, check_atol=0, **kwargs):
if "symmetric" in kwargs:
raise TypeError("The symmetric keyword argument was been removed in pymatsolver 0.4.0.")
super().__init__(A, check_accuracy=check_accuracy, check_rtol=check_rtol, check_atol=check_atol, **kwargs)
self._factored = False
self.maxiter = maxiter
self.rtol = rtol
Expand Down
59 changes: 11 additions & 48 deletions pymatsolver/solvers.py
Original file line number Diff line number Diff line change
Expand Up @@ -36,10 +36,6 @@ class Base(ABC):
The relative tolerance to check against for accuracy.
check_atol : float, optional
The absolute tolerance to check against for accuracy.
accuracy_tol : float, optional
Relative accuracy tolerance.
.. deprecated:: 0.3.0
`accuracy_tol` will be removed in pymatsolver 0.4.0. Use `check_rtol` and `check_atol` instead.
**kwargs
Extra keyword arguments. If there are any left here a warning will be raised.
"""
Expand All @@ -50,7 +46,7 @@ class Base(ABC):
_is_conjugate = False

def __init__(
self, A, is_symmetric=None, is_positive_definite=False, is_hermitian=None, check_accuracy=False, check_rtol=1e-6, check_atol=0, accuracy_tol=None, **kwargs
self, A, is_symmetric=None, is_positive_definite=False, is_hermitian=None, check_accuracy=False, check_rtol=1e-6, check_atol=0, **kwargs
):
# don't make any assumptions on what A is, let the individual solvers handle that
shape = A.shape
Expand All @@ -61,13 +57,8 @@ def __init__(
self._A = A
self._dtype = np.dtype(A.dtype)

if accuracy_tol is not None:
warnings.warn(
"accuracy_tol is deprecated and will be removed in v0.4.0, use check_rtol and check_atol.",
FutureWarning,
stacklevel=3
)
check_rtol = accuracy_tol
if 'accuracy_tol' in kwargs:
raise TypeError("'accuracy_tol' was removed in v0.4.0, use 'check_rtol' and 'check_atol'.")

self.check_accuracy = check_accuracy
self.check_rtol = check_rtol
Expand Down Expand Up @@ -341,14 +332,6 @@ def solve(self, rhs):
rhs = rhs.conjugate()
x = self._solve_single(rhs)
else:
if ndim == 2 and rhs.shape[-1] == 1:
warnings.warn(
"In Future pymatsolver v0.4.0, passing a vector of shape (n, 1) to the solve method "
"will return an array with shape (n, 1), instead of always returning a flattened array. "
"This is to be consistent with numpy.linalg.solve broadcasting.",
FutureWarning,
stacklevel=2
)
if rhs.shape[-2] != n:
raise ValueError(f'Second to last dimension should be {n}, got {rhs.shape}')
do_broadcast = rhs.ndim > 2
Expand Down Expand Up @@ -377,10 +360,6 @@ def solve(self, rhs):
if self.check_accuracy:
self._compute_accuracy(rhs, x)

#TODO remove this in v0.4.0.
if x.size == n:
x = x.reshape(-1)

if self._is_conjugate:
x = x.conjugate()
return x
Expand Down Expand Up @@ -449,15 +428,11 @@ class Diagonal(Base):
The relative tolerance to check against for accuracy.
check_atol : float, optional
The absolute tolerance to check against for accuracy.
accuracy_tol : float, optional
Relative accuracy tolerance.
.. deprecated:: 0.3.0
`accuracy_tol` will be removed in pymatsolver 0.4.0. Use `check_rtol` and `check_atol` instead.
**kwargs
Extra keyword arguments passed to the base class.
"""

def __init__(self, A, check_accuracy=False, check_rtol=1e-6, check_atol=0, accuracy_tol=None, **kwargs):
def __init__(self, A, check_accuracy=False, check_rtol=1e-6, check_atol=0, **kwargs):
try:
self._diagonal = np.asarray(A.diagonal())
if not np.all(self._diagonal):
Expand All @@ -469,7 +444,7 @@ def __init__(self, A, check_accuracy=False, check_rtol=1e-6, check_atol=0, accur
is_hermitian = kwargs.pop("is_hermitian", None)
is_positive_definite = kwargs.pop("is_positive_definite", None)
super().__init__(
A, is_symmetric=True, is_hermitian=False, check_accuracy=check_accuracy, check_rtol=check_rtol, check_atol=check_atol, accuracy_tol=accuracy_tol, **kwargs
A, is_symmetric=True, is_hermitian=False, check_accuracy=check_accuracy, check_rtol=check_rtol, check_atol=check_atol, **kwargs
)
if is_positive_definite is None:
if self.is_real:
Expand Down Expand Up @@ -510,23 +485,19 @@ class Triangle(Base):
The relative tolerance to check against for accuracy.
check_atol : float, optional
The absolute tolerance to check against for accuracy.
accuracy_tol : float, optional
Relative accuracy tolerance.
.. deprecated:: 0.3.0
`accuracy_tol` will be removed in pymatsolver 0.4.0. Use `check_rtol` and `check_atol` instead.
**kwargs
Extra keyword arguments passed to the base class.
"""

def __init__(self, A, lower=True, check_accuracy=False, check_rtol=1e-6, check_atol=0, accuracy_tol=None, **kwargs):
def __init__(self, A, lower=True, check_accuracy=False, check_rtol=1e-6, check_atol=0, **kwargs):
# pop off unneeded keyword arguments.
is_hermitian = kwargs.pop("is_hermitian", False)
is_symmetric = kwargs.pop("is_symmetric", False)
is_positive_definite = kwargs.pop("is_positive_definite", False)
if not (sp.issparse(A) and A.format in ['csr', 'csc']):
A = sp.csc_matrix(A)
A.sum_duplicates()
super().__init__(A, is_hermitian=is_hermitian, is_symmetric=is_symmetric, is_positive_definite=is_positive_definite, check_accuracy=check_accuracy, check_rtol=check_rtol, check_atol=check_atol, accuracy_tol=accuracy_tol, **kwargs)
super().__init__(A, is_hermitian=is_hermitian, is_symmetric=is_symmetric, is_positive_definite=is_positive_definite, check_accuracy=check_accuracy, check_rtol=check_rtol, check_atol=check_atol, **kwargs)

self.lower = lower

Expand Down Expand Up @@ -565,17 +536,13 @@ class Forward(Triangle):
The relative tolerance to check against for accuracy.
check_atol : float, optional
The absolute tolerance to check against for accuracy.
accuracy_tol : float, optional
Relative accuracy tolerance.
.. deprecated:: 0.3.0
`accuracy_tol` will be removed in pymatsolver 0.4.0. Use `check_rtol` and `check_atol` instead.
**kwargs
Extra keyword arguments passed to the base class.
"""

def __init__(self, A, check_accuracy=False, check_rtol=1e-6, check_atol=0, accuracy_tol=None, **kwargs):
def __init__(self, A, check_accuracy=False, check_rtol=1e-6, check_atol=0, **kwargs):
kwargs.pop("lower", None)
super().__init__(A, lower=True, check_accuracy=check_accuracy, check_rtol=check_rtol, check_atol=check_atol, accuracy_tol=accuracy_tol, **kwargs)
super().__init__(A, lower=True, check_accuracy=check_accuracy, check_rtol=check_rtol, check_atol=check_atol, **kwargs)


class Backward(Triangle):
Expand All @@ -591,19 +558,15 @@ class Backward(Triangle):
The relative tolerance to check against for accuracy.
check_atol : float, optional
The absolute tolerance to check against for accuracy.
accuracy_tol : float, optional
Relative accuracy tolerance.
.. deprecated:: 0.3.0
`accuracy_tol` will be removed in pymatsolver 0.4.0. Use `check_rtol` and `check_atol` instead.
**kwargs
Extra keyword arguments passed to the base class.
"""

_transpose_class = Forward

def __init__(self, A, check_accuracy=False, check_rtol=1e-6, check_atol=0, accuracy_tol=None, **kwargs):
def __init__(self, A, check_accuracy=False, check_rtol=1e-6, check_atol=0, **kwargs):
kwargs.pop("lower", None)
super().__init__(A, lower=False, check_accuracy=check_accuracy, check_rtol=check_rtol, check_atol=check_atol, accuracy_tol=accuracy_tol, **kwargs)
super().__init__(A, lower=False, check_accuracy=check_accuracy, check_rtol=check_rtol, check_atol=check_atol, **kwargs)


Forward._transpose_class = Backward
32 changes: 5 additions & 27 deletions pymatsolver/wrappers.py
Original file line number Diff line number Diff line change
Expand Up @@ -72,9 +72,9 @@ def wrap_direct(fun, factorize=True, name=None):
>>> SolverLU = pymatsolver.WrapDirect(splu, factorize=True)
"""

def __init__(self, A, check_accuracy=False, check_rtol=1E-6, check_atol=0, accuracy_tol=None, **kwargs):
def __init__(self, A, check_accuracy=False, check_rtol=1E-6, check_atol=0, **kwargs):
Base.__init__(
self, A, check_accuracy=check_accuracy, check_rtol=check_rtol, check_atol=check_atol, accuracy_tol=accuracy_tol,
self, A, check_accuracy=check_accuracy, check_rtol=check_rtol, check_atol=check_atol,
)
self.kwargs = kwargs
if factorize:
Expand Down Expand Up @@ -137,34 +137,20 @@ def clean(self):
The relative tolerance to check against for accuracy.
check_atol : float, optional
The absolute tolerance to check against for accuracy.
accuracy_tol : float, optional
Relative accuracy tolerance.
.. deprecated:: 0.3.0
`accuracy_tol` will be removed in pymatsolver 0.4.0. Use `check_rtol` and `check_atol` instead.
**kwargs
Extra keyword arguments which will attempted to be passed to the wrapped function.
"""
return WrappedClass


def wrap_iterative(fun, check_accuracy=None, accuracy_tol=None, name=None):
def wrap_iterative(fun, name=None):
"""
Wraps an iterative Solver.

Parameters
----------
fun : callable
The iterative Solver function.
check_accuracy : bool, optional
.. deprecated:: 0.3.0
The `check_accuracy` argument was previously unused. This will be
removed in a `pymatsolver` 0.4.0. Pass the relevant accuracy check parameters
to the wrapped class.
accuracy_tol : bool, optional
.. deprecated:: 0.3.0
The `check_accuracy` argument was previously unused. This will be
removed in a `pymatsolver` 0.4.0. Pass the relevant accuracy check parameters
to the wrapped class.
name : string, optional
The name of the wrapper class to construct. Defaults to the name of `fun`.

Expand All @@ -185,14 +171,10 @@ def wrap_iterative(fun, check_accuracy=None, accuracy_tol=None, name=None):
>>> SolverCG = pymatsolver.WrapIterative(cg)

"""
if check_accuracy is not None or accuracy_tol is not None:
warnings.warn('check_accuracy and accuracy_tol were unused and are now deprecated. They '
'will be removed in pymatsolver v0.4.0. Please pass the keyword arguments `check_rtol` '
'and check_atol directly to the wrapped solver class.', FutureWarning, stacklevel=2)

def __init__(self, A, check_accuracy=False, check_rtol=1E-6, check_atol=0, accuracy_tol=None, **kwargs):
def __init__(self, A, check_accuracy=False, check_rtol=1E-6, check_atol=0, **kwargs):
Base.__init__(
self, A, check_accuracy=check_accuracy, check_rtol=check_rtol, check_atol=check_atol, accuracy_tol=accuracy_tol,
self, A, check_accuracy=check_accuracy, check_rtol=check_rtol, check_atol=check_atol,
)
self.kwargs = kwargs

Expand Down Expand Up @@ -245,10 +227,6 @@ def _solve_multiple(self, rhs):
The relative tolerance to check against for accuracy.
check_atol : float, optional
The absolute tolerance to check against for accuracy.
accuracy_tol : float, optional
Relative accuracy tolerance.
.. deprecated:: 0.3.0
`accuracy_tol` will be removed in pymatsolver 0.4.0. Use `check_rtol` and `check_atol` instead.
**kwargs
Extra keyword arguments which will attempted to be passed to the wrapped function.
"""
Expand Down
12 changes: 7 additions & 5 deletions tests/test_Basic.py
Original file line number Diff line number Diff line change
Expand Up @@ -56,16 +56,22 @@ def test_basic_solve():
Ainv = IdentitySolver(np.eye(4))

rhs = np.arange(4)
rhs1d = np.arange(4).reshape(4, 1)
rhs2d = np.arange(8).reshape(4, 2)
rhs3d = np.arange(24).reshape(3, 4, 2)

npt.assert_equal(Ainv @ rhs, rhs)
npt.assert_equal(Ainv @ rhs1d, rhs1d)
npt.assert_equal(Ainv @ rhs2d, rhs2d)
npt.assert_equal(Ainv @ rhs3d, rhs3d)

npt.assert_equal(rhs @ Ainv, rhs)
npt.assert_equal(rhs * Ainv, rhs)


npt.assert_equal(rhs1d.T @ Ainv, rhs1d.T)
npt.assert_equal(rhs1d.T * Ainv, rhs1d.T)

npt.assert_equal(rhs2d.T @ Ainv, rhs2d.T)
npt.assert_equal(rhs2d.T * Ainv, rhs2d.T)

Expand All @@ -82,7 +88,7 @@ def test_errors_and_warnings():
with pytest.raises(ValueError, match="A is not a square matrix."):
IdentitySolver(np.full((3, 5), 1))

with pytest.warns(FutureWarning, match="accuracy_tol is deprecated.*"):
with pytest.raises(TypeError, match=r"'accuracy_tol' was removed.*"):
IdentitySolver(np.full((4, 4), 1), accuracy_tol=0.41)

with pytest.warns(UnusedArgumentWarning, match="Unused keyword arguments.*"):
Expand Down Expand Up @@ -111,10 +117,6 @@ def test_errors_and_warnings():
Ainv = IdentitySolver(np.eye(4, 4))
Ainv @ np.ones((3, 2))

with pytest.warns(FutureWarning, match="In Future pymatsolver v0.4.0, passing a vector.*"):
Ainv = IdentitySolver(np.eye(4, 4))
Ainv @ np.ones((4, 1))

with pytest.raises(NotImplementedError, match="The transpose for the.*"):
Ainv = NotTransposableIdentitySolver(np.eye(4, 4), is_symmetric=False)
Ainv.T
Expand Down
2 changes: 1 addition & 1 deletion tests/test_BicgJacobi.py
Original file line number Diff line number Diff line change
Expand Up @@ -39,7 +39,7 @@ def test_solve(test_mat_data, dtype, transpose, symmetric):

def test_errors_and_warnings(test_mat_data):
A, sol = test_mat_data
with pytest.warns(FutureWarning):
with pytest.raises(TypeError, match="The symmetric keyword.*"):
Ainv = BicgJacobi(A, symmetric=True)

with pytest.raises(ValueError):
Expand Down
13 changes: 10 additions & 3 deletions tests/test_Wrappers.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@
import warnings
import numpy.testing as npt
import numpy as np
import re


@pytest.mark.parametrize("solver_class", [SolverCG, SolverLU])
Expand Down Expand Up @@ -69,12 +70,18 @@ def clean(self):
assert Ainv.solver.A is None


def test_iterative_deprecations():
def test_iterative_removals():

with pytest.warns(FutureWarning, match="check_accuracy and accuracy_tol were unused.*"):
with pytest.raises(
TypeError,
match=re.escape("wrap_iterative() got an unexpected keyword argument 'check_accuracy'")
):
wrap_iterative(lambda a, x: x, check_accuracy=True)

with pytest.warns(FutureWarning, match="check_accuracy and accuracy_tol were unused.*"):
with pytest.raises(
TypeError,
match=re.escape("wrap_iterative() got an unexpected keyword argument 'accuracy_tol'")
):
wrap_iterative(lambda a, x: x, accuracy_tol=1E-3)


Expand Down
Loading