Skip to content

Commit d70a102

Browse files
authored
Khatrirao cleanup (#127)
* khatrirao: Clear up some debug and use generators
* khatrirao: Fix ambiguity and add typing
* Simplify interface to only take repeated matrices
* Update all uses
* Add test to clearly highlight users of the old interface
* khatrirao: Add pylint and enforcement
1 parent 3aa1deb commit d70a102

File tree

7 files changed

+54
-70
lines changed

7 files changed

+54
-70
lines changed

pyttb/cp_apr.py

Lines changed: 8 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -1171,8 +1171,10 @@ def tt_calcpi_prowsubprob(
11711171
Pi *= Model[i][Data.subs[sparse_indices, i], :]
11721172
else:
11731173
Pi = ttb.khatrirao(
1174-
Model.factor_matrices[:factorIndex]
1175-
+ Model.factor_matrices[factorIndex + 1 : ndims + 1],
1174+
*(
1175+
Model.factor_matrices[:factorIndex]
1176+
+ Model.factor_matrices[factorIndex + 1 : ndims + 1]
1177+
),
11761178
reverse=True,
11771179
)
11781180

@@ -1660,8 +1662,10 @@ def calculatePi(Data, Model, rank, factorIndex, ndims):
16601662
Pi *= Model[i][Data.subs[:, i], :]
16611663
else:
16621664
Pi = ttb.khatrirao(
1663-
Model.factor_matrices[:factorIndex]
1664-
+ Model.factor_matrices[factorIndex + 1 :],
1665+
*(
1666+
Model.factor_matrices[:factorIndex]
1667+
+ Model.factor_matrices[factorIndex + 1 :]
1668+
),
16651669
reverse=True,
16661670
)
16671671

pyttb/khatrirao.py

Lines changed: 26 additions & 50 deletions
Original file line numberDiff line numberDiff line change
@@ -1,11 +1,11 @@
11
# Copyright 2022 National Technology & Engineering Solutions of Sandia,
22
# LLC (NTESS). Under the terms of Contract DE-NA0003525 with NTESS, the
33
# U.S. Government retains certain rights in this software.
4-
4+
"""Khatri-Rao Product Implementation"""
55
import numpy as np
66

77

8-
def khatrirao(*listOfMatrices, reverse=False):
8+
def khatrirao(*matrices: np.ndarray, reverse: bool = False) -> np.ndarray:
99
"""
1010
KHATRIRAO Khatri-Rao product of matrices.
1111
@@ -16,69 +16,45 @@ def khatrirao(*listOfMatrices, reverse=False):
1616
1717
Parameters
1818
----------
19-
Matrices: [:class:`numpy.ndarray`] or :class:`numpy.ndarray`,:class:`numpy.ndarray`...
19+
matrices: Collection of matrices to take the product of
2020
reverse: bool Set to true to calculate product in reverse
2121
22-
Returns
23-
-------
24-
product: float
25-
2622
Examples
2723
--------
2824
>>> A = np.random.normal(size=(5,2))
2925
>>> B = np.random.normal(size=(5,2))
3026
>>> _ = khatrirao(A,B) #<-- Khatri-Rao of A and B
3127
>>> _ = khatrirao(B,A,reverse=True) #<-- same thing as above
32-
>>> _ = khatrirao([A,A,B]) #<-- passing a list
33-
>>> _ = khatrirao([B,A,A],reverse = True) #<-- same as above
28+
>>> _ = khatrirao(A,A,B) #<-- passing multiple items
29+
>>> _ = khatrirao(B,A,A,reverse = True) #<-- same as above
30+
>>> _ = khatrirao(*[A,A,B]) #<-- passing a list via unpacking items
3431
"""
3532
# Determine if list of matrices of multiple matrix arguments
36-
if isinstance(listOfMatrices[0], list):
37-
if len(listOfMatrices) == 1:
38-
listOfMatrices = listOfMatrices[0]
39-
else:
40-
assert (
41-
False
42-
), "Khatri Rao Acts on multiple Array arguments or a list of Arrays"
33+
if len(matrices) == 1 and isinstance(matrices[0], list):
34+
raise ValueError(
35+
"Khatrirao interface has changed. Instead of "
36+
" `khatrirao([matrix_a, matrix_b])` please update to use argument "
37+
"unpacking `khatrirao(*[matrix_a, matrix_b])`. This reduces ambiguity "
38+
"in usage moving forward. "
39+
)
40+
41+
if not isinstance(reverse, bool):
42+
raise ValueError(f"Expected a bool for reverse but received {reverse}")
4343

4444
# Error checking on input and set matrix order
45-
if reverse == True:
46-
listOfMatrices = list(reversed(listOfMatrices))
47-
ndimsA = [len(matrix.shape) == 2 for matrix in listOfMatrices]
48-
if not np.all(ndimsA):
45+
if reverse is True:
46+
matrices = tuple(reversed(matrices))
47+
if not all(len(matrix.shape) == 2 for matrix in matrices):
4948
assert False, "Each argument must be a matrix"
5049

51-
ncolFirst = listOfMatrices[0].shape[1]
52-
ncols = [matrix.shape[1] == ncolFirst for matrix in listOfMatrices]
53-
if not np.all(ncols):
50+
ncolFirst = matrices[0].shape[1]
51+
if not all(matrix.shape[1] == ncolFirst for matrix in matrices):
5452
assert False, "All matrices must have the same number of columns."
5553

5654
# Computation
57-
# print(f'A =\n {listOfMatrices}')
58-
P = listOfMatrices[0]
59-
# print(f'size_P = \n{P.shape}')
60-
# print(f'P = \n{P}')
61-
if ncolFirst == 1:
62-
for i in listOfMatrices[1:]:
63-
# print(f'size_Ai = \n{i.shape}')
64-
# print(f'size_reshape_Ai = \n{np.reshape(i, newshape=(-1, ncolFirst)).shape}')
65-
# print(f'size_P = \n{P.shape}')
66-
# print(f'size_reshape_P = \n{np.reshape(P, newshape=(ncolFirst, -1)).shape}')
67-
P = np.reshape(i, newshape=(-1, ncolFirst)) * np.reshape(
68-
P, newshape=(ncolFirst, -1), order="F"
69-
)
70-
# print(f'size_P = \n{P.shape}')
71-
# print(f'P = \n{P}')
72-
else:
73-
for i in listOfMatrices[1:]:
74-
# print(f'size_Ai = \n{i.shape}')
75-
# print(f'size_reshape_Ai = \n{np.reshape(i, newshape=(-1, 1, ncolFirst)).shape}')
76-
# print(f'size_P = \n{P.shape}')
77-
# print(f'size_reshape_P = \n{np.reshape(P, newshape=(1, -1, ncolFirst)).shape}')
78-
P = np.reshape(i, newshape=(-1, 1, ncolFirst)) * np.reshape(
79-
P, newshape=(1, -1, ncolFirst), order="F"
80-
)
81-
# print(f'size_P = \n{P.shape}')
82-
# print(f'P = \n{P}')
83-
55+
P = matrices[0]
56+
for i in matrices[1:]:
57+
P = np.reshape(i, newshape=(-1, 1, ncolFirst)) * np.reshape(
58+
P, newshape=(1, -1, ncolFirst), order="F"
59+
)
8460
return np.reshape(P, newshape=(-1, ncolFirst), order="F")

pyttb/ktensor.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -990,7 +990,7 @@ def full(self):
990990
[63. 85.]]
991991
<BLANKLINE>
992992
"""
993-
data = self.weights @ ttb.khatrirao(self.factor_matrices, reverse=True).T
993+
data = self.weights @ ttb.khatrirao(*self.factor_matrices, reverse=True).T
994994
return ttb.tensor.from_data(data, self.shape)
995995

996996
def innerprod(self, other):

pyttb/sptensor.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -348,7 +348,7 @@ def allsubs(self) -> np.ndarray:
348348
for n in range(0, self.ndims):
349349
i = o.copy()
350350
i[n] = np.expand_dims(np.arange(0, self.shape[n]), axis=1)
351-
s[:, n] = np.squeeze(ttb.khatrirao(i))
351+
s[:, n] = np.squeeze(ttb.khatrirao(*i))
352352

353353
return s.astype(int)
354354

@@ -1723,7 +1723,7 @@ def _set_subtensor(self, key, value):
17231723
i[n] = np.array(keyCopy[n])[:, None]
17241724
else:
17251725
i[n] = np.array(keyCopy[n], ndmin=2)
1726-
addsubs[:, n] = ttb.khatrirao(i).transpose()[:]
1726+
addsubs[:, n] = ttb.khatrirao(*i).transpose()[:]
17271727

17281728
if self.subs.size > 0:
17291729
# Replace existing values

pyttb/tensor.py

Lines changed: 6 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -702,16 +702,18 @@ def mttkrp(self, U: Union[ttb.ktensor, List[np.ndarray]], n: int) -> np.ndarray:
702702
szn = self.shape[n]
703703

704704
if n == 0:
705-
Ur = ttb.khatrirao(U[1 : self.ndims], reverse=True)
705+
Ur = ttb.khatrirao(*U[1 : self.ndims], reverse=True)
706706
Y = np.reshape(self.data, (szn, szr), order="F")
707707
return Y @ Ur
708708
if n == self.ndims - 1: # pylint: disable=no-else-return
709-
Ul = ttb.khatrirao(U[0 : self.ndims - 1], reverse=True)
709+
Ul = ttb.khatrirao(*U[0 : self.ndims - 1], reverse=True)
710710
Y = np.reshape(self.data, (szl, szn), order="F")
711711
return Y.T @ Ul
712712
else:
713-
Ul = ttb.khatrirao(U[n + 1 :], reverse=True)
714-
Ur = np.reshape(ttb.khatrirao(U[0:n], reverse=True), (szl, 1, R), order="F")
713+
Ul = ttb.khatrirao(*U[n + 1 :], reverse=True)
714+
Ur = np.reshape(
715+
ttb.khatrirao(*U[0:n], reverse=True), (szl, 1, R), order="F"
716+
)
715717
Y = np.reshape(self.data, (-1, szr), order="F")
716718
Y = Y @ Ul
717719
Y = np.reshape(Y, (szl, szn, R), order="F")

tests/test_khatrirao.py

Lines changed: 10 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -24,8 +24,8 @@ def test_khatrirao():
2424
[64, 125, 216],
2525
]
2626
)
27-
assert (ttb.khatrirao([A, A, A]) == answer).all()
28-
assert (ttb.khatrirao([A, A, A], reverse=True) == answer).all()
27+
assert (ttb.khatrirao(*[A, A, A]) == answer).all()
28+
assert (ttb.khatrirao(*[A, A, A], reverse=True) == answer).all()
2929
assert (ttb.khatrirao(A, A, A) == answer).all()
3030

3131
# Test case where inputs are column vectors
@@ -40,19 +40,20 @@ def test_khatrirao():
4040
a_2[3, 0] * np.ones((16, 1)),
4141
)
4242
)
43-
assert (ttb.khatrirao([a_2, a_1, a_1]) == result).all()
43+
assert (ttb.khatrirao(*[a_2, a_1, a_1]) == result).all()
4444
assert (ttb.khatrirao(a_2, a_1, a_1) == result).all()
4545

46-
with pytest.raises(AssertionError) as excinfo:
47-
ttb.khatrirao([a_2, a_1, a_1], a_2)
48-
assert "Khatri Rao Acts on multiple Array arguments or a list of Arrays" in str(
49-
excinfo
50-
)
51-
5246
with pytest.raises(AssertionError) as excinfo:
5347
ttb.khatrirao(a_2, a_1, np.ones((2, 2, 2)))
5448
assert "Each argument must be a matrix" in str(excinfo)
5549

5650
with pytest.raises(AssertionError) as excinfo:
5751
ttb.khatrirao(a_2, a_1, a_3)
5852
assert "All matrices must have the same number of columns." in str(excinfo)
53+
54+
# Check old interface error
55+
with pytest.raises(ValueError):
56+
ttb.khatrirao([a_1, a_1, a_1])
57+
58+
with pytest.raises(ValueError):
59+
ttb.khatrirao(a_1, a_1, reverse="cat")

tests/test_package.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -27,6 +27,7 @@ def test_linting():
2727
os.path.join(os.path.dirname(ttb.__file__), f"{ttb.tensor.__name__}.py"),
2828
os.path.join(os.path.dirname(ttb.__file__), f"{ttb.sptensor.__name__}.py"),
2929
ttb.pyttb_utils.__file__,
30+
os.path.join(os.path.dirname(ttb.__file__), f"{ttb.khatrirao.__name__}.py"),
3031
]
3132
# TODO pylint fails to import pyttb in tests
3233
# add mypy check

0 commit comments

Comments
 (0)