44
55
66class  SNMFOptimizer :
7-     """A  implementation of stretched NMF (sNMF), including sparse stretched NMF. 
7+     """An  implementation of stretched NMF (sNMF), including sparse stretched NMF. 
88
99    Instantiating the SNMFOptimizer class runs all the analysis immediately. 
1010    The results matrices can then be accessed as instance attributes 
@@ -117,35 +117,38 @@ def __init__(
117117        self .rho  =  rho 
118118        self .eta  =  eta 
119119        # Capture matrix dimensions 
120-         self ._signal_len , self ._num_conditions  =  source_matrix .shape 
120+         self .signal_length , self .n_signals  =  source_matrix .shape 
121121        self .num_updates  =  0 
122122        self ._rng  =  np .random .default_rng (random_state )
123123
124124        # Enforce exclusive specification of n_components or Y0 
125125        if  (n_components  is  None  and  init_weights  is  None ) or  (
126126            n_components  is  not None  and  init_weights  is  not None 
127127        ):
128-             raise  ValueError ("Must provide exactly one of init_weights or n_components, but not both." )
128+             raise  ValueError (
129+                 "Conflicting source for n_components. Must provide either init_weights or n_components " 
130+                 "directly, but not both." 
131+             )
129132
130133        # Initialize weights and determine number of components 
131134        if  init_weights  is  None :
132-             self ._n_components  =  n_components 
133-             self .weights  =  self ._rng .beta (a = 2.5 , b = 1.5 , size = (self ._n_components , self ._num_conditions ))
135+             self .n_components  =  n_components 
136+             self .weights  =  self ._rng .beta (a = 2.5 , b = 1.5 , size = (self .n_components , self .n_signals ))
134137        else :
135-             self ._n_components  =  init_weights .shape [0 ]
138+             self .n_components  =  init_weights .shape [0 ]
136139            self .weights  =  init_weights 
137140
138141        # Initialize stretching matrix if not provided 
139142        if  init_stretch  is  None :
140-             self .stretch  =  np .ones ((self ._n_components , self ._num_conditions )) +  self ._rng .normal (
141-                 0 , 1e-3 , size = (self ._n_components , self ._num_conditions )
143+             self .stretch  =  np .ones ((self .n_components , self .n_signals )) +  self ._rng .normal (
144+                 0 , 1e-3 , size = (self .n_components , self .n_signals )
142145            )
143146        else :
144147            self .stretch  =  init_stretch 
145148
146149        # Initialize component matrix if not provided 
147150        if  init_components  is  None :
148-             self .components  =  self ._rng .random ((self ._signal_len , self ._n_components ))
151+             self .components  =  self ._rng .random ((self .signal_length , self .n_components ))
149152        else :
150153            self .components  =  init_components 
151154
@@ -155,7 +158,7 @@ def __init__(
155158
156159        # Second-order spline: Tridiagonal (-2 on diagonal, 1 on sub/superdiagonals) 
157160        self .spline_smooth_operator  =  0.25  *  diags (
158-             [1 , - 2 , 1 ], offsets = [0 , 1 , 2 ], shape = (self ._num_conditions  -  2 , self ._num_conditions )
161+             [1 , - 2 , 1 ], offsets = [0 , 1 , 2 ], shape = (self .n_signals  -  2 , self .n_signals )
159162        )
160163        self .spline_smooth_penalty  =  self .spline_smooth_operator .T  @ self .spline_smooth_operator 
161164
@@ -351,34 +354,34 @@ def apply_interpolation_matrix(self, components=None, weights=None, stretch=None
351354            stretch  =  self .stretch 
352355
353356        # Compute scaled indices (MATLAB: AA = repmat(reshape(A',1,M*K).^-1, N,1)) 
354-         stretch_flat  =  stretch .reshape (1 , self ._num_conditions  *  self ._n_components ) **  - 1 
355-         stretch_tiled  =  np .tile (stretch_flat , (self ._signal_len , 1 ))
357+         stretch_flat  =  stretch .reshape (1 , self .n_signals  *  self .n_components ) **  - 1 
358+         stretch_tiled  =  np .tile (stretch_flat , (self .signal_length , 1 ))
356359
357360        # Compute `ii` (MATLAB: ii = repmat((0:N-1)',1,K*M).*tiled_stretch) 
358361        fractional_indices  =  (
359-             np .tile (np .arange (self ._signal_len )[:, None ], (1 , self ._num_conditions  *  self ._n_components ))
362+             np .tile (np .arange (self .signal_length )[:, None ], (1 , self .n_signals  *  self .n_components ))
360363            *  stretch_tiled 
361364        )
362365
363366        # Weighting matrix (MATLAB: YY = repmat(reshape(Y',1,M*K), N,1)) 
364-         weights_flat  =  weights .reshape (1 , self ._num_conditions  *  self ._n_components )
365-         weights_tiled  =  np .tile (weights_flat , (self ._signal_len , 1 ))
367+         weights_flat  =  weights .reshape (1 , self .n_signals  *  self .n_components )
368+         weights_tiled  =  np .tile (weights_flat , (self .signal_length , 1 ))
366369
367370        # Bias for indexing into reshaped X (MATLAB: bias = kron((0:K-1)*(N+1),ones(N,M))) 
368371        # TODO break this up or describe what it does better 
369372        bias  =  np .kron (
370-             np .arange (self ._n_components ) *  (self ._signal_len  +  1 ),
371-             np .ones ((self ._signal_len , self ._num_conditions ), dtype = int ),
372-         ).reshape (self ._signal_len , self ._n_components  *  self ._num_conditions )
373+             np .arange (self .n_components ) *  (self .signal_length  +  1 ),
374+             np .ones ((self .signal_length , self .n_signals ), dtype = int ),
375+         ).reshape (self .signal_length , self .n_components  *  self .n_signals )
373376
374377        # Handle boundary conditions for interpolation (MATLAB: X1=[X;X(end,:)]) 
375378        components_bounded  =  np .vstack ([components , components [- 1 , :]])  # Duplicate last row (like MATLAB) 
376379
377380        # Compute floor indices (MATLAB: II = floor(ii); II1=min(II+1,N+1); II2=min(II1+1,N+1)) 
378381        floor_indices  =  np .floor (fractional_indices ).astype (int )
379382
380-         floor_ind_1  =  np .minimum (floor_indices  +  1 , self ._signal_len )
381-         floor_ind_2  =  np .minimum (floor_ind_1  +  1 , self ._signal_len )
383+         floor_ind_1  =  np .minimum (floor_indices  +  1 , self .signal_length )
384+         floor_ind_2  =  np .minimum (floor_ind_1  +  1 , self .signal_length )
382385
383386        # Compute fractional part (MATLAB: iI = ii - II) 
384387        fractional_floor_indices  =  fractional_indices  -  floor_indices 
@@ -391,10 +394,10 @@ def apply_interpolation_matrix(self, components=None, weights=None, stretch=None
391394        # Note: this "-1" corrects an off-by-one error that may have originated in an earlier line 
392395        # order = F uses FORTRAN, column major order 
393396        components_val_1  =  components_bounded .flatten (order = "F" )[(offset_floor_ind_1  -  1 ).ravel ()].reshape (
394-             self ._signal_len , self ._n_components  *  self ._num_conditions 
397+             self .signal_length , self .n_components  *  self .n_signals 
395398        )
396399        components_val_2  =  components_bounded .flatten (order = "F" )[(offset_floor_ind_2  -  1 ).ravel ()].reshape (
397-             self ._signal_len , self ._n_components  *  self ._num_conditions 
400+             self .signal_length , self .n_components  *  self .n_signals 
398401        )
399402
400403        # Interpolation (MATLAB: Ax2=XI1.*(1-iI)+XI2.*(iI); stretched_components=Ax2.*YY) 
@@ -435,30 +438,30 @@ def apply_transformation_matrix(self, stretch=None, weights=None, residuals=None
435438
436439        # Compute scaling matrix (MATLAB: AA = repmat(reshape(A,1,M*K).^-1,Nindex,1)) 
437440        stretch_tiled  =  np .tile (
438-             stretch .reshape (1 , self ._num_conditions  *  self ._n_components , order = "F" ) **  - 1 , (self ._signal_len , 1 )
441+             stretch .reshape (1 , self .n_signals  *  self .n_components , order = "F" ) **  - 1 , (self .signal_length , 1 )
439442        )
440443
441444        # Compute indices (MATLAB: ii = repmat((index-1)',1,K*M).*AA) 
442-         indices  =  np .arange (self ._signal_len )[:, None ] *  stretch_tiled   # Shape (N, M*K), replacing `index` 
445+         indices  =  np .arange (self .signal_length )[:, None ] *  stretch_tiled   # Shape (N, M*K), replacing `index` 
443446
444447        # Weighting coefficients (MATLAB: YY = repmat(reshape(Y,1,M*K),Nindex,1)) 
445448        weights_tiled  =  np .tile (
446-             weights .reshape (1 , self ._num_conditions  *  self ._n_components , order = "F" ), (self ._signal_len , 1 )
449+             weights .reshape (1 , self .n_signals  *  self .n_components , order = "F" ), (self .signal_length , 1 )
447450        )
448451
449452        # Compute floor indices (MATLAB: II = floor(ii); II1 = min(II+1,N+1); II2 = min(II1+1,N+1)) 
450453        floor_indices  =  np .floor (indices ).astype (int )
451-         floor_indices_1  =  np .minimum (floor_indices  +  1 , self ._signal_len )
452-         floor_indices_2  =  np .minimum (floor_indices_1  +  1 , self ._signal_len )
454+         floor_indices_1  =  np .minimum (floor_indices  +  1 , self .signal_length )
455+         floor_indices_2  =  np .minimum (floor_indices_1  +  1 , self .signal_length )
453456
454457        # Compute fractional part (MATLAB: iI = ii - II) 
455458        fractional_indices  =  indices  -  floor_indices 
456459
457460        # Expand row indices (MATLAB: repm = repmat(1:K, Nindex, M)) 
458-         repm  =  np .tile (np .arange (self ._n_components ), (self ._signal_len , self ._num_conditions ))
461+         repm  =  np .tile (np .arange (self .n_components ), (self .signal_length , self .n_signals ))
459462
460463        # Compute transformations (MATLAB: kro = kron(R(index,:), ones(1, K))) 
461-         kron  =  np .kron (residuals , np .ones ((1 , self ._n_components )))
464+         kron  =  np .kron (residuals , np .ones ((1 , self .n_components )))
462465
463466        # (MATLAB: kroiI = kro .* (iI); iIYY = (iI-1) .* YY) 
464467        fractional_kron  =  kron  *  fractional_indices 
@@ -467,16 +470,16 @@ def apply_transformation_matrix(self, stretch=None, weights=None, residuals=None
467470        # Construct sparse matrices (MATLAB: sparse(II1_,repm,kro.*-iIYY,(N+1),K)) 
468471        x2  =  coo_matrix (
469472            ((- kron  *  fractional_weights ).flatten (), (floor_indices_1 .flatten () -  1 , repm .flatten ())),
470-             shape = (self ._signal_len  +  1 , self ._n_components ),
473+             shape = (self .signal_length  +  1 , self .n_components ),
471474        ).tocsc ()
472475        x3  =  coo_matrix (
473476            ((fractional_kron  *  weights_tiled ).flatten (), (floor_indices_2 .flatten () -  1 , repm .flatten ())),
474-             shape = (self ._signal_len  +  1 , self ._n_components ),
477+             shape = (self .signal_length  +  1 , self .n_components ),
475478        ).tocsc ()
476479
477480        # Combine the last row into previous, then remove the last row 
478-         x2 [self ._signal_len  -  1 , :] +=  x2 [self ._signal_len , :]
479-         x3 [self ._signal_len  -  1 , :] +=  x3 [self ._signal_len , :]
481+         x2 [self .signal_length  -  1 , :] +=  x2 [self .signal_length , :]
482+         x3 [self .signal_length  -  1 , :] +=  x3 [self .signal_length , :]
480483        x2  =  x2 [:- 1 , :]
481484        x3  =  x3 [:- 1 , :]
482485
@@ -543,10 +546,10 @@ def update_components(self):
543546        stretched_components , _ , _  =  self .apply_interpolation_matrix ()  # Skip the other two outputs (derivatives) 
544547        # Compute RA and RR 
545548        intermediate_reshaped  =  stretched_components .flatten (order = "F" ).reshape (
546-             (self ._signal_len  *  self ._num_conditions , self ._n_components ), order = "F" 
549+             (self .signal_length  *  self .n_signals , self .n_components ), order = "F" 
547550        )
548551        reshaped_stretched_components  =  intermediate_reshaped .sum (axis = 1 ).reshape (
549-             (self ._signal_len , self ._num_conditions ), order = "F" 
552+             (self .signal_length , self .n_signals ), order = "F" 
550553        )
551554        component_residuals  =  reshaped_stretched_components  -  self .source_matrix 
552555        # Compute gradient `GraX` 
@@ -603,11 +606,11 @@ def update_weights(self):
603606        Updates weights using matrix operations, solving a quadratic program via to do so. 
604607        """ 
605608
606-         for  m  in  range (self ._num_conditions ):
607-             t  =  np .zeros ((self ._signal_len , self ._n_components ))
609+         for  m  in  range (self .n_signals ):
610+             t  =  np .zeros ((self .signal_length , self .n_components ))
608611
609612            # Populate T using apply_interpolation 
610-             for  k  in  range (self ._n_components ):
613+             for  k  in  range (self .n_components ):
611614                t [:, k ] =  self .apply_interpolation (
612615                    self .stretch [k , m ], self .components [:, k ], return_derivatives = True 
613616                )[0 ].squeeze ()
@@ -635,21 +638,19 @@ def regularize_function(self, stretch=None):
635638
636639        # Compute residual 
637640        intermediate_diff  =  stretched_components .flatten (order = "F" ).reshape (
638-             (self ._signal_len  *  self ._num_conditions , self ._n_components ), order = "F" 
639-         )
640-         stretch_difference  =  intermediate_diff .sum (axis = 1 ).reshape (
641-             (self ._signal_len , self ._num_conditions ), order = "F" 
641+             (self .signal_length  *  self .n_signals , self .n_components ), order = "F" 
642642        )
643+         stretch_difference  =  intermediate_diff .sum (axis = 1 ).reshape ((self .signal_length , self .n_signals ), order = "F" )
643644        stretch_difference  =  stretch_difference  -  self .source_matrix 
644645
645646        # Compute objective function 
646647        reg_func  =  self .get_objective_function (stretch_difference , stretch )
647648
648649        # Compute gradient 
649650        tiled_derivative  =  np .sum (
650-             d_stretch_components  *  np .tile (stretch_difference , (1 , self ._n_components )), axis = 0 
651+             d_stretch_components  *  np .tile (stretch_difference , (1 , self .n_components )), axis = 0 
651652        )
652-         der_reshaped  =  np .asarray (tiled_derivative ).reshape ((self ._num_conditions , self ._n_components ), order = "F" )
653+         der_reshaped  =  np .asarray (tiled_derivative ).reshape ((self .n_signals , self .n_components ), order = "F" )
653654        func_grad  =  (
654655            der_reshaped .T  +  self .rho  *  stretch  @ self .spline_smooth_operator .T  @ self .spline_smooth_operator 
655656        )
0 commit comments