

class SNMFOptimizer:
-    def __init__(self, MM, Y0=None, X0=None, A=None, rho=1e12, eta=610, max_iter=500, tol=5e-7, components=None):
-        print("Initializing SNMF Optimizer")
+    """An implementation of stretched NMF (sNMF), including sparse stretched NMF.
+
+    Instantiating the SNMFOptimizer class runs all the analysis immediately.
+    The results matrices can then be accessed as instance attributes
+    of the class (X, Y, and A).
+
+    For more information on sNMF, please reference:
+    Gu, R., Rakita, Y., Lan, L. et al. Stretched non-negative matrix factorization.
+    npj Comput Mater 10, 193 (2024). https://doi.org/10.1038/s41524-024-01377-5
+
+    Attributes
+    ----------
+    MM : ndarray
+        The original, unmodified data to be decomposed and, later, compared against.
+        Shape is (length_of_signal, number_of_conditions).
+    Y : ndarray
+        The best guess (or, while running, the current guess) for the matrix of
+        component weights at each stretching condition.
+    X : ndarray
+        The best guess (or, while running, the current guess) for the matrix of
+        component intensities.
+    A : ndarray
+        The best guess (or, while running, the current guess) for the matrix of
+        stretching factors.
+    rho : float
+        The weight of the stretching regularization term. Zero corresponds to no
+        stretching regularization. Results are relatively insensitive to its exact
+        value, which is typically adjusted in powers of 10.
+    eta : float
+        The weight of the sparsity regularization term. Should be set to zero for
+        non-sparse data such as PDF. Can be used to improve results for sparse data
+        such as XRD, but due to instability, should be adjusted only after first
+        selecting the best value for rho. Suggested adjustment is by powers of 2.
+    max_iter : int
+        The maximum number of times to update each of A, X, and Y before stopping
+        the optimization.
+    tol : float
+        The convergence threshold. This is the minimum fractional improvement in the
+        objective function required for the optimization to continue; below it, the
+        optimization terminates. Note that a minimum of 20 updates are run before
+        this parameter is checked.
+    n_components : int
+        The number of components to extract from MM. Must be provided when, and only
+        when, Y0 is not provided.
+    random_state : int
+        The seed for the initial guesses at the matrices (A, X, and Y) created by
+        the decomposition.
+    num_updates : int
+        The total number of times that any of A, X, and Y have had their values
+        changed. If not terminated by other means, this value is used to stop the
+        optimization upon reaching max_iter.
+    objective_function : float
+        The current value of the objective function, which measures the mismatch
+        between MM and the reconstruction built from A, X, and Y, plus the
+        regularization terms. For full details see the sNMF paper. Smaller values
+        correspond to better agreement and are desirable.
+    objective_difference : float
+        The decrease in the objective function since the last update. A positive
+        value means that the result improved.
+    """
+
+    def __init__(
+        self,
+        MM,
+        Y0=None,
+        X0=None,
+        A0=None,
+        rho=1e12,
+        eta=610,
+        max_iter=500,
+        tol=5e-7,
+        n_components=None,
+        random_state=None,
+    ):
+        """Initialize an instance of SNMF and run the optimization.
+
+        Parameters
+        ----------
+        MM : ndarray
+            The data to be decomposed. Shape is (length_of_signal, number_of_conditions).
+        Y0 : ndarray
+            The initial guesses for the component weights at each stretching condition.
+            Shape is (number_of_components, number_of_conditions). Must provide exactly
+            one of this or n_components.
+        X0 : ndarray
+            The initial guesses for the intensities of each component per
+            row/sample/angle. Shape is (length_of_signal, number_of_components).
+        A0 : ndarray
+            The initial guesses for the stretching factor for each component, at each
+            condition. Shape is (number_of_components, number_of_conditions).
+        rho : float
+            The weight of the stretching regularization term. Zero corresponds to no
+            stretching regularization. Results are relatively insensitive to its exact
+            value, which is typically adjusted in powers of 10.
+        eta : float
+            The weight of the sparsity regularization term. Should be set to zero for
+            non-sparse data such as PDF. Can be used to improve results for sparse data
+            such as XRD, but due to instability, should be adjusted only after first
+            selecting the best value for rho. Suggested adjustment is by powers of 2.
+        max_iter : int
+            The maximum number of times to update each of A, X, and Y before stopping
+            the optimization.
+        tol : float
+            The convergence threshold. This is the minimum fractional improvement in the
+            objective function required for the optimization to continue; below it, the
+            optimization terminates. Note that a minimum of 20 updates are run before
+            this parameter is checked.
+        n_components : int
+            The number of components to extract from MM. Must be provided when, and only
+            when, Y0 is not provided.
+        random_state : int
+            The seed for the initial guesses at the matrices (A, X, and Y) created by
+            the decomposition.
+        """
+
        self.MM = MM
-        self.X0 = X0
-        self.Y0 = Y0
-        self.A = A
        self.rho = rho
        self.eta = eta
        # Capture matrix dimensions
-        self.N, self.M = MM.shape
+        self._N, self._M = MM.shape
        self.num_updates = 0
+        self._rng = np.random.default_rng(random_state)

+        # Enforce exclusive specification of n_components or Y0
+        if (n_components is None) == (Y0 is None):
+            raise ValueError("Must provide exactly one of Y0 or n_components, but not both.")
+
+        # Initialize Y0 and determine number of components
        if Y0 is None:
-            if components is None:
-                raise ValueError("Must provide either Y0 or a number of components.")
-            else:
-                self.K = components
-                self.Y0 = np.random.beta(a=2.5, b=1.5, size=(self.K, self.M))  # This is untested
+            self._K = n_components
+            self.Y = self._rng.beta(a=2.5, b=1.5, size=(self._K, self._M))
        else:
-            self.K = Y0.shape[0]
+            self._K = Y0.shape[0]
+            self.Y = Y0

-        # Initialize A, X0 if not provided
-        if self.A is None:
-            self.A = np.ones((self.K, self.M)) + np.random.randn(self.K, self.M) * 1e-3  # Small perturbation
-        if self.X0 is None:
-            self.X0 = np.random.rand(self.N, self.K)  # Ensures values in [0,1]
+        # Initialize A if not provided
+        if A0 is None:
+            self.A = np.ones((self._K, self._M)) + self._rng.normal(0, 1e-3, size=(self._K, self._M))
+        else:
+            self.A = A0
+
+        # Initialize X if not provided
+        if X0 is None:
+            self.X = self._rng.random((self._N, self._K))
+        else:
+            self.X = X0

-        # Initialize solution matrices to be iterated on
-        self.X = np.maximum(0, self.X0)
-        self.Y = np.maximum(0, self.Y0)
+        # Enforce non-negativity
+        self.X = np.maximum(0, self.X)
+        self.Y = np.maximum(0, self.Y)

        # Second-order difference operator: each row applies the [1, -2, 1] stencil
-        self.P = 0.25 * diags([1, -2, 1], offsets=[0, 1, 2], shape=(self.M - 2, self.M))
+        self.P = 0.25 * diags([1, -2, 1], offsets=[0, 1, 2], shape=(self._M - 2, self._M))
        self.PP = self.P.T @ self.P

        # Set up residual matrix, objective function, and history
        self.R = self.get_residual_matrix()
        self.objective_function = self.get_objective_function()
        self.objective_difference = None
-        self.objective_history = [self.objective_function]
+        self._objective_history = [self.objective_function]

        # Set up tracking variables for updateX()
-        self.preX = None
-        self.GraX = np.zeros_like(self.X)  # Gradient of X (zeros for now)
-        self.preGraX = np.zeros_like(self.X)  # Previous gradient of X (zeros for now)
+        self._preX = None
+        self._GraX = np.zeros_like(self.X)  # Gradient of X (zeros for now)
+        self._preGraX = np.zeros_like(self.X)  # Previous gradient of X (zeros for now)

        regularization_term = 0.5 * rho * np.linalg.norm(self.P @ self.A.T, "fro") ** 2
        sparsity_term = eta * np.sum(np.sqrt(self.X))  # Square root penalty
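
To make the workflow above concrete, here is a minimal usage sketch based on the class docstring; the import path, the synthetic data, and the chosen shapes are assumptions for illustration, not part of this change.

    import numpy as np
    from diffpy.snmf.snmf_class import SNMFOptimizer  # import path assumed for this sketch

    # Synthetic non-negative data: 1000-point signals measured at 10 conditions.
    MM = np.abs(np.random.default_rng(0).normal(size=(1000, 10)))

    # Instantiating the optimizer runs the full decomposition immediately.
    model = SNMFOptimizer(MM, n_components=2, random_state=0)

    # Results are exposed as attributes once the constructor returns.
    print(model.X.shape)  # (1000, 2) component intensities
    print(model.Y.shape)  # (2, 10) component weights per condition
    print(model.A.shape)  # (2, 10) stretching factors per condition
    print(model.objective_function)
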
@@ -83,53 +197,53 @@ def __init__(self, MM, Y0=None, X0=None, A=None, rho=1e12, eta=610, max_iter=500
        # loop to normalize X
        # effectively just re-running class with non-normalized X, normalized Y/A as inputs, then only update X
        # reset difference trackers and initialize
-        self.preX = None
-        self.GraX = np.zeros_like(self.X)  # Gradient of X (zeros for now)
-        self.preGraX = np.zeros_like(self.X)  # Previous gradient of X (zeros for now)
+        self._preX = None
+        self._GraX = np.zeros_like(self.X)  # Gradient of X (zeros for now)
+        self._preGraX = np.zeros_like(self.X)  # Previous gradient of X (zeros for now)
        self.R = self.get_residual_matrix()
        self.objective_function = self.get_objective_function()
        self.objective_difference = None
-        self.objective_history = [self.objective_function]
+        self._objective_history = [self.objective_function]
        for norm_iter in range(100):
            self.updateX()
            self.R = self.get_residual_matrix()
            self.objective_function = self.get_objective_function()
            print(f"Objective function after normX: {self.objective_function:.5e}")
-            self.objective_history.append(self.objective_function)
-            self.objective_difference = self.objective_history[-2] - self.objective_history[-1]
+            self._objective_history.append(self.objective_function)
+            self.objective_difference = self._objective_history[-2] - self._objective_history[-1]
            if self.objective_difference < self.objective_function * tol and norm_iter >= 20:
                break
        # end of normalization (and program)
        # note that objective function may not fully recover after normalization, this is okay
        print("Finished optimization.")

    def optimize_loop(self):
-        self.preGraX = self.GraX.copy()
+        self._preGraX = self._GraX.copy()
        self.updateX()
        self.num_updates += 1
        self.R = self.get_residual_matrix()
        self.objective_function = self.get_objective_function()
        print(f"Objective function after updateX: {self.objective_function:.5e}")
-        self.objective_history.append(self.objective_function)
+        self._objective_history.append(self.objective_function)
        if self.objective_difference is None:
-            self.objective_difference = self.objective_history[-1] - self.objective_function
+            self.objective_difference = self._objective_history[-1] - self.objective_function

        # Now we update Y
        self.updateY2()
        self.num_updates += 1
        self.R = self.get_residual_matrix()
        self.objective_function = self.get_objective_function()
        print(f"Objective function after updateY2: {self.objective_function:.5e}")
-        self.objective_history.append(self.objective_function)
+        self._objective_history.append(self.objective_function)

        self.updateA2()

        self.num_updates += 1
        self.R = self.get_residual_matrix()
        self.objective_function = self.get_objective_function()
        print(f"Objective function after updateA2: {self.objective_function:.5e}")
-        self.objective_history.append(self.objective_function)
-        self.objective_difference = self.objective_history[-2] - self.objective_history[-1]
+        self._objective_history.append(self.objective_function)
+        self.objective_difference = self._objective_history[-2] - self._objective_history[-1]

    def apply_interpolation(self, a, x, return_derivatives=False):
        """
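
The outer driver loop that repeatedly calls optimize_loop() is not part of the hunks shown here. Based on the max_iter and tol descriptions and the structure of the normalization loop above, it presumably looks something like the following sketch (the method name _optimize and the exact stopping logic are assumptions):

    # Hypothetical sketch only; the actual driver is not included in this diff.
    def _optimize(self, max_iter, tol):
        for outer_iter in range(max_iter):
            self.optimize_loop()  # one round of updateX / updateY2 / updateA2
            # Stop once the fractional improvement falls below tol, but only after
            # a minimum number of rounds, mirroring the normalization loop above.
            if outer_iter >= 20 and self.objective_difference < self.objective_function * tol:
                break
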
@@ -401,36 +515,38 @@ def updateX(self):
        # Compute `AX` using the interpolation function
        AX, _, _ = self.apply_interpolation_matrix()  # Skip the other two outputs
        # Compute RA and RR
-        intermediate_RA = AX.flatten(order="F").reshape((self.N * self.M, self.K), order="F")
-        RA = intermediate_RA.sum(axis=1).reshape((self.N, self.M), order="F")
+        intermediate_RA = AX.flatten(order="F").reshape((self._N * self._M, self._K), order="F")
+        RA = intermediate_RA.sum(axis=1).reshape((self._N, self._M), order="F")
        RR = RA - self.MM
        # Compute gradient `GraX`
-        self.GraX = self.apply_transformation_matrix(R=RR).toarray()  # toarray equivalent of full, make non-sparse
+        self._GraX = self.apply_transformation_matrix(
+            R=RR
+        ).toarray()  # toarray equivalent of full, make non-sparse

        # Compute initial step size `L0`
        L0 = np.linalg.eigvalsh(self.Y.T @ self.Y).max() * np.max([self.A.max(), 1 / self.A.min()])
        # Compute adaptive step size `L`
-        if self.preX is None:
+        if self._preX is None:
            L = L0
        else:
-            num = np.sum((self.GraX - self.preGraX) * (self.X - self.preX))  # Element-wise multiplication
-            denom = np.linalg.norm(self.X - self.preX, "fro") ** 2  # Frobenius norm squared
+            num = np.sum((self._GraX - self._preGraX) * (self.X - self._preX))  # Element-wise multiplication
+            denom = np.linalg.norm(self.X - self._preX, "fro") ** 2  # Frobenius norm squared
            L = num / denom if denom > 0 else L0
            if L <= 0:
                L = L0

        # Store our old X before updating because it is used in step selection
-        self.preX = self.X.copy()
+        self._preX = self.X.copy()

        while True:  # iterate updating X
-            x_step = self.preX - self.GraX / L
+            x_step = self._preX - self._GraX / L
            # Solve x^3 + p*x + q = 0 for the largest real root
            self.X = np.square(cubic_largest_real_root(-x_step, self.eta / (2 * L)))
            # Mask values that should be set to zero
            mask = self.X**2 * L / 2 - L * self.X * x_step + self.eta * np.sqrt(self.X) < 0
            self.X = mask * self.X

-            objective_improvement = self.objective_history[-1] - self.get_objective_function(
+            objective_improvement = self._objective_history[-1] - self.get_objective_function(
                R=self.get_residual_matrix()
            )

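
For reference, the adaptive step size computed in updateX is a Barzilai-Borwein-style curvature estimate. Restating the code (not a formula quoted from the paper), with k and k-1 denoting the current and previous iterates and GraX the gradient of the data-fidelity term with respect to X:

    L = \frac{\langle \nabla_X f(X_k) - \nabla_X f(X_{k-1}),\; X_k - X_{k-1} \rangle}{\lVert X_k - X_{k-1} \rVert_F^{2}},
    \qquad L \leftarrow L_0 \ \text{if the denominator is zero or } L \le 0.
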
@@ -447,9 +563,9 @@ def updateY2(self):
        Updates Y using matrix operations, solving a quadratic program via `solve_mkr_box`.
        """

-        K = self.K
-        N = self.N
-        M = self.M
+        K = self._K
+        N = self._N
+        M = self._M

        for m in range(M):
            T = np.zeros((N, K))  # Initialize T as an (N, K) zero matrix
@@ -476,9 +592,9 @@ def regularize_function(self, A=None):
        if A is None:
            A = self.A

-        K = self.K
-        M = self.M
-        N = self.N
+        K = self._K
+        M = self._M
+        N = self._N

        # Compute interpolated matrices
        AX, TX, HX = self.apply_interpolation_matrix(A=A, return_derivatives=True)
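
For reference, the two penalty terms assembled at the end of the first hunk (regularization_term and sparsity_term) can be written out explicitly; this merely restates what the code computes, with rho and eta the constructor arguments, P the second-order difference operator defined above, and X and A as described in the class docstring:

    \text{regularization\_term} = \tfrac{1}{2}\,\rho\,\lVert P A^{\mathsf T} \rVert_F^{2},
    \qquad
    \text{sparsity\_term} = \eta \sum_{n,k} \sqrt{X_{nk}}.

The data-fidelity part of the objective is built from the residual matrix R returned by get_residual_matrix(); see the sNMF paper linked in the class docstring for the complete objective function.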