@@ -113,13 +113,15 @@ def __init__(
113113 The number of components to extract from source_matrix. Must be provided when and only when
114114 init_weights is not provided.
115115 random_state : int Optional Default = None
116- The seed for the initial guesses at the matrices (A, X , and Y ) created by
116+ The seed for the initial guesses at the matrices (stretch, components , and weights ) created by
117117 the decomposition.
118118 """
119119
120120 self .source_matrix = source_matrix
121121 self .rho = rho
122122 self .eta = eta
123+ self .tol = tol
124+ self .max_iter = max_iter
123125 # Capture matrix dimensions
124126 self .signal_length , self .n_signals = source_matrix .shape
125127 self .num_updates = 0
@@ -164,18 +166,19 @@ def __init__(
164166 self ._spline_smooth_operator = 0.25 * diags (
165167 [1 , - 2 , 1 ], offsets = [0 , 1 , 2 ], shape = (self .n_signals - 2 , self .n_signals )
166168 )
167- self ._spline_smooth_penalty = self ._spline_smooth_operator .T @ self ._spline_smooth_operator
168169
169170 # Set up residual matrix, objective function, and history
170171 self .residuals = self .get_residual_matrix ()
171- self ._objective_history = []
172- self .update_objective ()
172+ self .objective_function = self .get_objective_function ()
173+ self .best_objective = self .objective_function
174+ self .best_matrices = [self .components .copy (), self .weights .copy (), self .stretch .copy ()]
173175 self .objective_difference = None
176+ self ._objective_history = [self .objective_function ]
174177
175178 # Set up tracking variables for update_components()
176179 self ._prev_components = None
177- self .grad_components = np .zeros_like (self .components ) # Gradient of X (zeros for now)
178- self ._prev_grad_components = np .zeros_like (self .components ) # Previous gradient of X (zeros for now)
180+ self .grad_components = np .zeros_like (self .components )
181+ self ._prev_grad_components = np .zeros_like (self .components )
179182
180183 regularization_term = 0.5 * rho * np .linalg .norm (self ._spline_smooth_operator @ self .stretch .T , "fro" ) ** 2
181184 sparsity_term = eta * np .sum (np .sqrt (self .components )) # Square root penalty
@@ -265,57 +268,6 @@ def optimize_loop(self):
265268
266269 self .objective_difference = self ._objective_history [- 2 ] - self ._objective_history [- 1 ]
267270
268- def apply_interpolation (self , a , x , return_derivatives = False ):
269- """
270- Applies an interpolation-based transformation to `x` based on scaling `a`.
271- Also can compute first (`d_intr_x`) and second (`dd_intr_x`) derivatives.
272- """
273- x_len = len (x )
274-
275- # Ensure `a` is an array and reshape for broadcasting
276- a = np .atleast_1d (np .asarray (a )) # Ensures a is at least 1D
277-
278- # Compute fractional indices, broadcasting over `a`
279- fractional_indices = np .arange (x_len )[:, None ] / a # Shape (N, M)
280-
281- integer_indices = np .floor (fractional_indices ).astype (int ) # Integer part (still (N, M))
282- valid_mask = integer_indices < (x_len - 1 ) # Ensure indices are within bounds
283-
284- # Apply valid_mask to keep correct indices
285- idx_int = np .where (
286- valid_mask , integer_indices , x_len - 2
287- ) # Prevent out-of-bounds indexing (previously "I")
288- idx_frac = np .where (valid_mask , fractional_indices , integer_indices ) # Keep aligned (previously "i")
289-
290- # Ensure x is a 1D array
291- x = np .asarray (x ).ravel ()
292-
293- # Compute interpolated_x (linear interpolation)
294- interpolated_x = x [idx_int ] * (1 - idx_frac + idx_int ) + x [np .minimum (idx_int + 1 , x_len - 1 )] * (
295- idx_frac - idx_int
296- )
297-
298- # Fill the tail with the last valid value
299- intr_x_tail = np .full ((x_len - len (idx_int ), interpolated_x .shape [1 ]), interpolated_x [- 1 , :])
300- interpolated_x = np .vstack ([interpolated_x , intr_x_tail ])
301-
302- if return_derivatives :
303- # Compute first derivative (d_intr_x)
304- di = - idx_frac / a
305- d_intr_x = x [idx_int ] * (- di ) + x [np .minimum (idx_int + 1 , x_len - 1 )] * di
306- d_intr_x = np .vstack ([d_intr_x , np .zeros ((x_len - len (idx_int ), d_intr_x .shape [1 ]))])
307-
308- # Compute second derivative (dd_intr_x)
309- ddi = - di / a + idx_frac * a ** - 2
310- dd_intr_x = x [idx_int ] * (- ddi ) + x [np .minimum (idx_int + 1 , x_len - 1 )] * ddi
311- dd_intr_x = np .vstack ([dd_intr_x , np .zeros ((x_len - len (idx_int ), dd_intr_x .shape [1 ]))])
312- else :
313- # Make placeholders
314- d_intr_x = np .empty (interpolated_x .shape )
315- dd_intr_x = np .empty (interpolated_x .shape )
316-
317- return interpolated_x , d_intr_x , dd_intr_x
318-
319271 def get_residual_matrix (self , components = None , weights = None , stretch = None ):
320272 # Initialize residual matrix as negative of source_matrix
321273 # In MATLAB this is getR
@@ -722,3 +674,53 @@ def cubic_largest_real_root(p, q):
722674
723675 # Choose correct root depending on sign of delta
724676 return np .where (delta >= 0 , root1 , root2 )
677+
678+
679+ def apply_interpolation (self , a , x , return_derivatives = False ):
680+ """
681+ Applies an interpolation-based transformation to `x` based on scaling `a`.
682+ Also can compute first (`d_intr_x`) and second (`dd_intr_x`) derivatives.
683+ """
684+ x_len = len (x )
685+
686+ # Ensure `a` is an array and reshape for broadcasting
687+ a = np .atleast_1d (np .asarray (a )) # Ensures a is at least 1D
688+
689+ # Compute fractional indices, broadcasting over `a`
690+ fractional_indices = np .arange (x_len )[:, None ] / a # Shape (N, M)
691+
692+ integer_indices = np .floor (fractional_indices ).astype (int ) # Integer part (still (N, M))
693+ valid_mask = integer_indices < (x_len - 1 ) # Ensure indices are within bounds
694+
695+ # Apply valid_mask to keep correct indices
696+ idx_int = np .where (valid_mask , integer_indices , x_len - 2 ) # Prevent out-of-bounds indexing (previously "I")
697+ idx_frac = np .where (valid_mask , fractional_indices , integer_indices ) # Keep aligned (previously "i")
698+
699+ # Ensure x is a 1D array
700+ x = np .asarray (x ).ravel ()
701+
702+ # Compute interpolated_x (linear interpolation)
703+ interpolated_x = x [idx_int ] * (1 - idx_frac + idx_int ) + x [np .minimum (idx_int + 1 , x_len - 1 )] * (
704+ idx_frac - idx_int
705+ )
706+
707+ # Fill the tail with the last valid value
708+ intr_x_tail = np .full ((x_len - len (idx_int ), interpolated_x .shape [1 ]), interpolated_x [- 1 , :])
709+ interpolated_x = np .vstack ([interpolated_x , intr_x_tail ])
710+
711+ if return_derivatives :
712+ # Compute first derivative (d_intr_x)
713+ di = - idx_frac / a
714+ d_intr_x = x [idx_int ] * (- di ) + x [np .minimum (idx_int + 1 , x_len - 1 )] * di
715+ d_intr_x = np .vstack ([d_intr_x , np .zeros ((x_len - len (idx_int ), d_intr_x .shape [1 ]))])
716+
717+ # Compute second derivative (dd_intr_x)
718+ ddi = - di / a + idx_frac * a ** - 2
719+ dd_intr_x = x [idx_int ] * (- ddi ) + x [np .minimum (idx_int + 1 , x_len - 1 )] * ddi
720+ dd_intr_x = np .vstack ([dd_intr_x , np .zeros ((x_len - len (idx_int ), dd_intr_x .shape [1 ]))])
721+ else :
722+ # Make placeholders
723+ d_intr_x = np .empty (interpolated_x .shape )
724+ dd_intr_x = np .empty (interpolated_x .shape )
725+
726+ return interpolated_x , d_intr_x , dd_intr_x
0 commit comments