11# Copyright 2022 National Technology & Engineering Solutions of Sandia,
22# LLC (NTESS). Under the terms of Contract DE-NA0003525 with NTESS, the
33# U.S. Government retains certain rights in this software.
4-
4+ """Khatri-Rao Product Implementation"""
55import numpy as np
66
77
8- def khatrirao (* listOfMatrices , reverse = False ):
8+ def khatrirao (* matrices : np . ndarray , reverse : bool = False ) -> np . ndarray :
99 """
1010 KHATRIRAO Khatri-Rao product of matrices.
1111
@@ -16,69 +16,45 @@ def khatrirao(*listOfMatrices, reverse=False):
1616
1717 Parameters
1818 ----------
19- Matrices: [:class:`numpy.ndarray`] or :class:`numpy.ndarray`,:class:`numpy.ndarray`...
19+ matrices: Collection of matrices to take the product of
2020 reverse: bool Set to true to calculate product in reverse
2121
22- Returns
23- -------
24- product: float
25-
2622 Examples
2723 --------
2824 >>> A = np.random.normal(size=(5,2))
2925 >>> B = np.random.normal(size=(5,2))
3026 >>> _ = khatrirao(A,B) #<-- Khatri-Rao of A and B
3127 >>> _ = khatrirao(B,A,reverse=True) #<-- same thing as above
32- >>> _ = khatrirao([A,A,B]) #<-- passing a list
33- >>> _ = khatrirao([B,A,A],reverse = True) #<-- same as above
28+ >>> _ = khatrirao(A,A,B) #<-- passing multiple items
29+ >>> _ = khatrirao(B,A,A,reverse = True) #<-- same as above
30+ >>> _ = khatrirao(*[A,A,B]) #<-- passing a list via unpacking items
3431 """
3532 # Determine if list of matrices of multiple matrix arguments
36- if isinstance (listOfMatrices [0 ], list ):
37- if len (listOfMatrices ) == 1 :
38- listOfMatrices = listOfMatrices [0 ]
39- else :
40- assert (
41- False
42- ), "Khatri Rao Acts on multiple Array arguments or a list of Arrays"
33+ if len (matrices ) == 1 and isinstance (matrices [0 ], list ):
34+ raise ValueError (
35+ "Khatrirao interface has changed. Instead of "
36+ " `khatrirao([matrix_a, matrix_b])` please update to use argument "
37+ "unpacking `khatrirao(*[matrix_a, matrix_b])`. This reduces ambiguity "
38+ "in usage moving forward. "
39+ )
40+
41+ if not isinstance (reverse , bool ):
42+ raise ValueError (f"Expected a bool for reverse but received { reverse } " )
4343
4444 # Error checking on input and set matrix order
45- if reverse == True :
46- listOfMatrices = list (reversed (listOfMatrices ))
47- ndimsA = [len (matrix .shape ) == 2 for matrix in listOfMatrices ]
48- if not np .all (ndimsA ):
45+ if reverse is True :
46+ matrices = tuple (reversed (matrices ))
47+ if not all (len (matrix .shape ) == 2 for matrix in matrices ):
4948 assert False , "Each argument must be a matrix"
5049
51- ncolFirst = listOfMatrices [0 ].shape [1 ]
52- ncols = [matrix .shape [1 ] == ncolFirst for matrix in listOfMatrices ]
53- if not np .all (ncols ):
50+ ncolFirst = matrices [0 ].shape [1 ]
51+ if not all (matrix .shape [1 ] == ncolFirst for matrix in matrices ):
5452 assert False , "All matrices must have the same number of columns."
5553
5654 # Computation
57- # print(f'A =\n {listOfMatrices}')
58- P = listOfMatrices [0 ]
59- # print(f'size_P = \n{P.shape}')
60- # print(f'P = \n{P}')
61- if ncolFirst == 1 :
62- for i in listOfMatrices [1 :]:
63- # print(f'size_Ai = \n{i.shape}')
64- # print(f'size_reshape_Ai = \n{np.reshape(i, newshape=(-1, ncolFirst)).shape}')
65- # print(f'size_P = \n{P.shape}')
66- # print(f'size_reshape_P = \n{np.reshape(P, newshape=(ncolFirst, -1)).shape}')
67- P = np .reshape (i , newshape = (- 1 , ncolFirst )) * np .reshape (
68- P , newshape = (ncolFirst , - 1 ), order = "F"
69- )
70- # print(f'size_P = \n{P.shape}')
71- # print(f'P = \n{P}')
72- else :
73- for i in listOfMatrices [1 :]:
74- # print(f'size_Ai = \n{i.shape}')
75- # print(f'size_reshape_Ai = \n{np.reshape(i, newshape=(-1, 1, ncolFirst)).shape}')
76- # print(f'size_P = \n{P.shape}')
77- # print(f'size_reshape_P = \n{np.reshape(P, newshape=(1, -1, ncolFirst)).shape}')
78- P = np .reshape (i , newshape = (- 1 , 1 , ncolFirst )) * np .reshape (
79- P , newshape = (1 , - 1 , ncolFirst ), order = "F"
80- )
81- # print(f'size_P = \n{P.shape}')
82- # print(f'P = \n{P}')
83-
55+ P = matrices [0 ]
56+ for i in matrices [1 :]:
57+ P = np .reshape (i , newshape = (- 1 , 1 , ncolFirst )) * np .reshape (
58+ P , newshape = (1 , - 1 , ncolFirst ), order = "F"
59+ )
8460 return np .reshape (P , newshape = (- 1 , ncolFirst ), order = "F" )
0 commit comments