From 442dbd303a263a5ae58ebc67169f649c3c2aa516 Mon Sep 17 00:00:00 2001 From: Nick Johnson <24689722+ntjohnson1@users.noreply.github.com> Date: Fri, 2 Jun 2023 18:43:52 -0400 Subject: [PATCH 1/3] khatrirao: Clear up some debug and use generators --- pyttb/khatrirao.py | 25 +++---------------------- 1 file changed, 3 insertions(+), 22 deletions(-) diff --git a/pyttb/khatrirao.py b/pyttb/khatrirao.py index aded4497..ab6cd631 100644 --- a/pyttb/khatrirao.py +++ b/pyttb/khatrirao.py @@ -44,41 +44,22 @@ def khatrirao(*listOfMatrices, reverse=False): # Error checking on input and set matrix order if reverse == True: listOfMatrices = list(reversed(listOfMatrices)) - ndimsA = [len(matrix.shape) == 2 for matrix in listOfMatrices] - if not np.all(ndimsA): + if not all(len(matrix.shape) == 2 for matrix in listOfMatrices): assert False, "Each argument must be a matrix" ncolFirst = listOfMatrices[0].shape[1] - ncols = [matrix.shape[1] == ncolFirst for matrix in listOfMatrices] - if not np.all(ncols): + if not all(matrix.shape[1] == ncolFirst for matrix in listOfMatrices): assert False, "All matrices must have the same number of columns." # Computation - # print(f'A =\n {listOfMatrices}') P = listOfMatrices[0] - # print(f'size_P = \n{P.shape}') - # print(f'P = \n{P}') if ncolFirst == 1: for i in listOfMatrices[1:]: - # print(f'size_Ai = \n{i.shape}') - # print(f'size_reshape_Ai = \n{np.reshape(i, newshape=(-1, ncolFirst)).shape}') - # print(f'size_P = \n{P.shape}') - # print(f'size_reshape_P = \n{np.reshape(P, newshape=(ncolFirst, -1)).shape}') - P = np.reshape(i, newshape=(-1, ncolFirst)) * np.reshape( - P, newshape=(ncolFirst, -1), order="F" - ) - # print(f'size_P = \n{P.shape}') - # print(f'P = \n{P}') + P = i * np.reshape(P, newshape=(ncolFirst, -1), order="F") else: for i in listOfMatrices[1:]: - # print(f'size_Ai = \n{i.shape}') - # print(f'size_reshape_Ai = \n{np.reshape(i, newshape=(-1, 1, ncolFirst)).shape}') - # print(f'size_P = \n{P.shape}') - # print(f'size_reshape_P = \n{np.reshape(P, newshape=(1, -1, ncolFirst)).shape}') P = np.reshape(i, newshape=(-1, 1, ncolFirst)) * np.reshape( P, newshape=(1, -1, ncolFirst), order="F" ) - # print(f'size_P = \n{P.shape}') - # print(f'P = \n{P}') return np.reshape(P, newshape=(-1, ncolFirst), order="F") From f6ad68f13b225d9777f516a2476d7a18ce9e03ca Mon Sep 17 00:00:00 2001 From: Nick Johnson <24689722+ntjohnson1@users.noreply.github.com> Date: Sat, 3 Jun 2023 11:50:09 -0400 Subject: [PATCH 2/3] khatrirao: Fix ambiguity and add typing * Simplify interface to only take repeated matrices * Update all uses * Add test to clearly highlight users of the old interface --- pyttb/cp_apr.py | 12 ++++++---- pyttb/khatrirao.py | 52 ++++++++++++++++++----------------------- pyttb/ktensor.py | 2 +- pyttb/sptensor.py | 4 ++-- pyttb/tensor.py | 10 ++++---- tests/test_khatrirao.py | 19 ++++++++------- 6 files changed, 50 insertions(+), 49 deletions(-) diff --git a/pyttb/cp_apr.py b/pyttb/cp_apr.py index 89020a0c..e9afa110 100644 --- a/pyttb/cp_apr.py +++ b/pyttb/cp_apr.py @@ -1171,8 +1171,10 @@ def tt_calcpi_prowsubprob( Pi *= Model[i][Data.subs[sparse_indices, i], :] else: Pi = ttb.khatrirao( - Model.factor_matrices[:factorIndex] - + Model.factor_matrices[factorIndex + 1 : ndims + 1], + *( + Model.factor_matrices[:factorIndex] + + Model.factor_matrices[factorIndex + 1 : ndims + 1] + ), reverse=True, ) @@ -1660,8 +1662,10 @@ def calculatePi(Data, Model, rank, factorIndex, ndims): Pi *= Model[i][Data.subs[:, i], :] else: Pi = ttb.khatrirao( - Model.factor_matrices[:factorIndex] - + Model.factor_matrices[factorIndex + 1 :], + *( + Model.factor_matrices[:factorIndex] + + Model.factor_matrices[factorIndex + 1 :] + ), reverse=True, ) diff --git a/pyttb/khatrirao.py b/pyttb/khatrirao.py index ab6cd631..225c7f61 100644 --- a/pyttb/khatrirao.py +++ b/pyttb/khatrirao.py @@ -5,7 +5,7 @@ import numpy as np -def khatrirao(*listOfMatrices, reverse=False): +def khatrirao(*matrices: np.ndarray, reverse: bool = False) -> np.ndarray: """ KHATRIRAO Khatri-Rao product of matrices. @@ -16,50 +16,44 @@ def khatrirao(*listOfMatrices, reverse=False): Parameters ---------- - Matrices: [:class:`numpy.ndarray`] or :class:`numpy.ndarray`,:class:`numpy.ndarray`... + matrices: Collection of matrices to take the product of reverse: bool Set to true to calculate product in reverse - Returns - ------- - product: float - Examples -------- >>> A = np.random.normal(size=(5,2)) >>> B = np.random.normal(size=(5,2)) >>> _ = khatrirao(A,B) #<-- Khatri-Rao of A and B >>> _ = khatrirao(B,A,reverse=True) #<-- same thing as above - >>> _ = khatrirao([A,A,B]) #<-- passing a list - >>> _ = khatrirao([B,A,A],reverse = True) #<-- same as above + >>> _ = khatrirao(A,A,B) #<-- passing multiple items + >>> _ = khatrirao(B,A,A,reverse = True) #<-- same as above + >>> _ = khatrirao(*[A,A,B]) #<-- passing a list via unpacking items """ # Determine if list of matrices of multiple matrix arguments - if isinstance(listOfMatrices[0], list): - if len(listOfMatrices) == 1: - listOfMatrices = listOfMatrices[0] - else: - assert ( - False - ), "Khatri Rao Acts on multiple Array arguments or a list of Arrays" + if len(matrices) == 1 and isinstance(matrices[0], list): + raise ValueError( + "Khatrirao interface has changed. Instead of `khatrirao([matrix_a, matrix_b])` " + "please update to use argument unpacking `khatrirao(*[matrix_a, matrix_b])`. " + "This reduces ambiguity in usage moving forward. " + ) + + if not isinstance(reverse, bool): + raise ValueError(f"Expected a bool for reverse but received {reverse}") # Error checking on input and set matrix order if reverse == True: - listOfMatrices = list(reversed(listOfMatrices)) - if not all(len(matrix.shape) == 2 for matrix in listOfMatrices): + matrices = tuple(reversed(matrices)) + if not all(len(matrix.shape) == 2 for matrix in matrices): assert False, "Each argument must be a matrix" - ncolFirst = listOfMatrices[0].shape[1] - if not all(matrix.shape[1] == ncolFirst for matrix in listOfMatrices): + ncolFirst = matrices[0].shape[1] + if not all(matrix.shape[1] == ncolFirst for matrix in matrices): assert False, "All matrices must have the same number of columns." # Computation - P = listOfMatrices[0] - if ncolFirst == 1: - for i in listOfMatrices[1:]: - P = i * np.reshape(P, newshape=(ncolFirst, -1), order="F") - else: - for i in listOfMatrices[1:]: - P = np.reshape(i, newshape=(-1, 1, ncolFirst)) * np.reshape( - P, newshape=(1, -1, ncolFirst), order="F" - ) - + P = matrices[0] + for i in matrices[1:]: + P = np.reshape(i, newshape=(-1, 1, ncolFirst)) * np.reshape( + P, newshape=(1, -1, ncolFirst), order="F" + ) return np.reshape(P, newshape=(-1, ncolFirst), order="F") diff --git a/pyttb/ktensor.py b/pyttb/ktensor.py index 58b068d3..95e82293 100644 --- a/pyttb/ktensor.py +++ b/pyttb/ktensor.py @@ -990,7 +990,7 @@ def full(self): [63. 85.]] """ - data = self.weights @ ttb.khatrirao(self.factor_matrices, reverse=True).T + data = self.weights @ ttb.khatrirao(*self.factor_matrices, reverse=True).T return ttb.tensor.from_data(data, self.shape) def innerprod(self, other): diff --git a/pyttb/sptensor.py b/pyttb/sptensor.py index 9aeb83a5..59bc60d9 100644 --- a/pyttb/sptensor.py +++ b/pyttb/sptensor.py @@ -348,7 +348,7 @@ def allsubs(self) -> np.ndarray: for n in range(0, self.ndims): i = o.copy() i[n] = np.expand_dims(np.arange(0, self.shape[n]), axis=1) - s[:, n] = np.squeeze(ttb.khatrirao(i)) + s[:, n] = np.squeeze(ttb.khatrirao(*i)) return s.astype(int) @@ -1723,7 +1723,7 @@ def _set_subtensor(self, key, value): i[n] = np.array(keyCopy[n])[:, None] else: i[n] = np.array(keyCopy[n], ndmin=2) - addsubs[:, n] = ttb.khatrirao(i).transpose()[:] + addsubs[:, n] = ttb.khatrirao(*i).transpose()[:] if self.subs.size > 0: # Replace existing values diff --git a/pyttb/tensor.py b/pyttb/tensor.py index b415824d..2658afd2 100644 --- a/pyttb/tensor.py +++ b/pyttb/tensor.py @@ -702,16 +702,18 @@ def mttkrp(self, U: Union[ttb.ktensor, List[np.ndarray]], n: int) -> np.ndarray: szn = self.shape[n] if n == 0: - Ur = ttb.khatrirao(U[1 : self.ndims], reverse=True) + Ur = ttb.khatrirao(*U[1 : self.ndims], reverse=True) Y = np.reshape(self.data, (szn, szr), order="F") return Y @ Ur if n == self.ndims - 1: # pylint: disable=no-else-return - Ul = ttb.khatrirao(U[0 : self.ndims - 1], reverse=True) + Ul = ttb.khatrirao(*U[0 : self.ndims - 1], reverse=True) Y = np.reshape(self.data, (szl, szn), order="F") return Y.T @ Ul else: - Ul = ttb.khatrirao(U[n + 1 :], reverse=True) - Ur = np.reshape(ttb.khatrirao(U[0:n], reverse=True), (szl, 1, R), order="F") + Ul = ttb.khatrirao(*U[n + 1 :], reverse=True) + Ur = np.reshape( + ttb.khatrirao(*U[0:n], reverse=True), (szl, 1, R), order="F" + ) Y = np.reshape(self.data, (-1, szr), order="F") Y = Y @ Ul Y = np.reshape(Y, (szl, szn, R), order="F") diff --git a/tests/test_khatrirao.py b/tests/test_khatrirao.py index 9dd00ceb..47bdbd37 100644 --- a/tests/test_khatrirao.py +++ b/tests/test_khatrirao.py @@ -24,8 +24,8 @@ def test_khatrirao(): [64, 125, 216], ] ) - assert (ttb.khatrirao([A, A, A]) == answer).all() - assert (ttb.khatrirao([A, A, A], reverse=True) == answer).all() + assert (ttb.khatrirao(*[A, A, A]) == answer).all() + assert (ttb.khatrirao(*[A, A, A], reverse=True) == answer).all() assert (ttb.khatrirao(A, A, A) == answer).all() # Test case where inputs are column vectors @@ -40,15 +40,9 @@ def test_khatrirao(): a_2[3, 0] * np.ones((16, 1)), ) ) - assert (ttb.khatrirao([a_2, a_1, a_1]) == result).all() + assert (ttb.khatrirao(*[a_2, a_1, a_1]) == result).all() assert (ttb.khatrirao(a_2, a_1, a_1) == result).all() - with pytest.raises(AssertionError) as excinfo: - ttb.khatrirao([a_2, a_1, a_1], a_2) - assert "Khatri Rao Acts on multiple Array arguments or a list of Arrays" in str( - excinfo - ) - with pytest.raises(AssertionError) as excinfo: ttb.khatrirao(a_2, a_1, np.ones((2, 2, 2))) assert "Each argument must be a matrix" in str(excinfo) @@ -56,3 +50,10 @@ def test_khatrirao(): with pytest.raises(AssertionError) as excinfo: ttb.khatrirao(a_2, a_1, a_3) assert "All matrices must have the same number of columns." in str(excinfo) + + # Check old interface error + with pytest.raises(ValueError): + ttb.khatrirao([a_1, a_1, a_1]) + + with pytest.raises(ValueError): + ttb.khatrirao(a_1, a_1, reverse="cat") From 563acea2d69f017bfb7d41b08acdff18b7c1d26c Mon Sep 17 00:00:00 2001 From: Nick Johnson <24689722+ntjohnson1@users.noreply.github.com> Date: Sat, 3 Jun 2023 12:01:15 -0400 Subject: [PATCH 3/3] khatrirao: Add pylint and enforcement --- pyttb/khatrirao.py | 11 ++++++----- tests/test_package.py | 1 + 2 files changed, 7 insertions(+), 5 deletions(-) diff --git a/pyttb/khatrirao.py b/pyttb/khatrirao.py index 225c7f61..083b1814 100644 --- a/pyttb/khatrirao.py +++ b/pyttb/khatrirao.py @@ -1,7 +1,7 @@ # Copyright 2022 National Technology & Engineering Solutions of Sandia, # LLC (NTESS). Under the terms of Contract DE-NA0003525 with NTESS, the # U.S. Government retains certain rights in this software. - +"""Khatri-Rao Product Implementation""" import numpy as np @@ -32,16 +32,17 @@ def khatrirao(*matrices: np.ndarray, reverse: bool = False) -> np.ndarray: # Determine if list of matrices of multiple matrix arguments if len(matrices) == 1 and isinstance(matrices[0], list): raise ValueError( - "Khatrirao interface has changed. Instead of `khatrirao([matrix_a, matrix_b])` " - "please update to use argument unpacking `khatrirao(*[matrix_a, matrix_b])`. " - "This reduces ambiguity in usage moving forward. " + "Khatrirao interface has changed. Instead of " + " `khatrirao([matrix_a, matrix_b])` please update to use argument " + "unpacking `khatrirao(*[matrix_a, matrix_b])`. This reduces ambiguity " + "in usage moving forward. " ) if not isinstance(reverse, bool): raise ValueError(f"Expected a bool for reverse but received {reverse}") # Error checking on input and set matrix order - if reverse == True: + if reverse is True: matrices = tuple(reversed(matrices)) if not all(len(matrix.shape) == 2 for matrix in matrices): assert False, "Each argument must be a matrix" diff --git a/tests/test_package.py b/tests/test_package.py index fe87ea42..5011603a 100644 --- a/tests/test_package.py +++ b/tests/test_package.py @@ -27,6 +27,7 @@ def test_linting(): os.path.join(os.path.dirname(ttb.__file__), f"{ttb.tensor.__name__}.py"), os.path.join(os.path.dirname(ttb.__file__), f"{ttb.sptensor.__name__}.py"), ttb.pyttb_utils.__file__, + os.path.join(os.path.dirname(ttb.__file__), f"{ttb.khatrirao.__name__}.py"), ] # TODO pylint fails to import pyttb in tests # add mypy check