Skip to content

Commit 915f052

Browse files
author
John Halloran
committed
style: rename and make public certain attributes
1 parent f691b46 commit 915f052

File tree

1 file changed

+45
-44
lines changed

1 file changed

+45
-44
lines changed

src/diffpy/snmf/snmf_class.py

Lines changed: 45 additions & 44 deletions
Original file line numberDiff line numberDiff line change
@@ -4,7 +4,7 @@
44

55

66
class SNMFOptimizer:
7-
"""A implementation of stretched NMF (sNMF), including sparse stretched NMF.
7+
"""An implementation of stretched NMF (sNMF), including sparse stretched NMF.
88
99
Instantiating the SNMFOptimizer class runs all the analysis immediately.
1010
The results matrices can then be accessed as instance attributes
@@ -117,35 +117,38 @@ def __init__(
117117
self.rho = rho
118118
self.eta = eta
119119
# Capture matrix dimensions
120-
self._signal_len, self._num_conditions = source_matrix.shape
120+
self.signal_length, self.n_signals = source_matrix.shape
121121
self.num_updates = 0
122122
self._rng = np.random.default_rng(random_state)
123123

124124
# Enforce exclusive specification of n_components or Y0
125125
if (n_components is None and init_weights is None) or (
126126
n_components is not None and init_weights is not None
127127
):
128-
raise ValueError("Must provide exactly one of init_weights or n_components, but not both.")
128+
raise ValueError(
129+
"Conflicting source for n_components. Must provide either init_weights or n_components "
130+
"directly, but not both."
131+
)
129132

130133
# Initialize weights and determine number of components
131134
if init_weights is None:
132-
self._n_components = n_components
133-
self.weights = self._rng.beta(a=2.5, b=1.5, size=(self._n_components, self._num_conditions))
135+
self.n_components = n_components
136+
self.weights = self._rng.beta(a=2.5, b=1.5, size=(self.n_components, self.n_signals))
134137
else:
135-
self._n_components = init_weights.shape[0]
138+
self.n_components = init_weights.shape[0]
136139
self.weights = init_weights
137140

138141
# Initialize stretching matrix if not provided
139142
if init_stretch is None:
140-
self.stretch = np.ones((self._n_components, self._num_conditions)) + self._rng.normal(
141-
0, 1e-3, size=(self._n_components, self._num_conditions)
143+
self.stretch = np.ones((self.n_components, self.n_signals)) + self._rng.normal(
144+
0, 1e-3, size=(self.n_components, self.n_signals)
142145
)
143146
else:
144147
self.stretch = init_stretch
145148

146149
# Initialize component matrix if not provided
147150
if init_components is None:
148-
self.components = self._rng.random((self._signal_len, self._n_components))
151+
self.components = self._rng.random((self.signal_length, self.n_components))
149152
else:
150153
self.components = init_components
151154

@@ -155,7 +158,7 @@ def __init__(
155158

156159
# Second-order spline: Tridiagonal (-2 on diagonal, 1 on sub/superdiagonals)
157160
self.spline_smooth_operator = 0.25 * diags(
158-
[1, -2, 1], offsets=[0, 1, 2], shape=(self._num_conditions - 2, self._num_conditions)
161+
[1, -2, 1], offsets=[0, 1, 2], shape=(self.n_signals - 2, self.n_signals)
159162
)
160163
self.spline_smooth_penalty = self.spline_smooth_operator.T @ self.spline_smooth_operator
161164

@@ -351,34 +354,34 @@ def apply_interpolation_matrix(self, components=None, weights=None, stretch=None
351354
stretch = self.stretch
352355

353356
# Compute scaled indices (MATLAB: AA = repmat(reshape(A',1,M*K).^-1, N,1))
354-
stretch_flat = stretch.reshape(1, self._num_conditions * self._n_components) ** -1
355-
stretch_tiled = np.tile(stretch_flat, (self._signal_len, 1))
357+
stretch_flat = stretch.reshape(1, self.n_signals * self.n_components) ** -1
358+
stretch_tiled = np.tile(stretch_flat, (self.signal_length, 1))
356359

357360
# Compute `ii` (MATLAB: ii = repmat((0:N-1)',1,K*M).*tiled_stretch)
358361
fractional_indices = (
359-
np.tile(np.arange(self._signal_len)[:, None], (1, self._num_conditions * self._n_components))
362+
np.tile(np.arange(self.signal_length)[:, None], (1, self.n_signals * self.n_components))
360363
* stretch_tiled
361364
)
362365

363366
# Weighting matrix (MATLAB: YY = repmat(reshape(Y',1,M*K), N,1))
364-
weights_flat = weights.reshape(1, self._num_conditions * self._n_components)
365-
weights_tiled = np.tile(weights_flat, (self._signal_len, 1))
367+
weights_flat = weights.reshape(1, self.n_signals * self.n_components)
368+
weights_tiled = np.tile(weights_flat, (self.signal_length, 1))
366369

367370
# Bias for indexing into reshaped X (MATLAB: bias = kron((0:K-1)*(N+1),ones(N,M)))
368371
# TODO break this up or describe what it does better
369372
bias = np.kron(
370-
np.arange(self._n_components) * (self._signal_len + 1),
371-
np.ones((self._signal_len, self._num_conditions), dtype=int),
372-
).reshape(self._signal_len, self._n_components * self._num_conditions)
373+
np.arange(self.n_components) * (self.signal_length + 1),
374+
np.ones((self.signal_length, self.n_signals), dtype=int),
375+
).reshape(self.signal_length, self.n_components * self.n_signals)
373376

374377
# Handle boundary conditions for interpolation (MATLAB: X1=[X;X(end,:)])
375378
components_bounded = np.vstack([components, components[-1, :]]) # Duplicate last row (like MATLAB)
376379

377380
# Compute floor indices (MATLAB: II = floor(ii); II1=min(II+1,N+1); II2=min(II1+1,N+1))
378381
floor_indices = np.floor(fractional_indices).astype(int)
379382

380-
floor_ind_1 = np.minimum(floor_indices + 1, self._signal_len)
381-
floor_ind_2 = np.minimum(floor_ind_1 + 1, self._signal_len)
383+
floor_ind_1 = np.minimum(floor_indices + 1, self.signal_length)
384+
floor_ind_2 = np.minimum(floor_ind_1 + 1, self.signal_length)
382385

383386
# Compute fractional part (MATLAB: iI = ii - II)
384387
fractional_floor_indices = fractional_indices - floor_indices
@@ -391,10 +394,10 @@ def apply_interpolation_matrix(self, components=None, weights=None, stretch=None
391394
# Note: this "-1" corrects an off-by-one error that may have originated in an earlier line
392395
# order = F uses FORTRAN, column major order
393396
components_val_1 = components_bounded.flatten(order="F")[(offset_floor_ind_1 - 1).ravel()].reshape(
394-
self._signal_len, self._n_components * self._num_conditions
397+
self.signal_length, self.n_components * self.n_signals
395398
)
396399
components_val_2 = components_bounded.flatten(order="F")[(offset_floor_ind_2 - 1).ravel()].reshape(
397-
self._signal_len, self._n_components * self._num_conditions
400+
self.signal_length, self.n_components * self.n_signals
398401
)
399402

400403
# Interpolation (MATLAB: Ax2=XI1.*(1-iI)+XI2.*(iI); stretched_components=Ax2.*YY)
@@ -435,30 +438,30 @@ def apply_transformation_matrix(self, stretch=None, weights=None, residuals=None
435438

436439
# Compute scaling matrix (MATLAB: AA = repmat(reshape(A,1,M*K).^-1,Nindex,1))
437440
stretch_tiled = np.tile(
438-
stretch.reshape(1, self._num_conditions * self._n_components, order="F") ** -1, (self._signal_len, 1)
441+
stretch.reshape(1, self.n_signals * self.n_components, order="F") ** -1, (self.signal_length, 1)
439442
)
440443

441444
# Compute indices (MATLAB: ii = repmat((index-1)',1,K*M).*AA)
442-
indices = np.arange(self._signal_len)[:, None] * stretch_tiled # Shape (N, M*K), replacing `index`
445+
indices = np.arange(self.signal_length)[:, None] * stretch_tiled # Shape (N, M*K), replacing `index`
443446

444447
# Weighting coefficients (MATLAB: YY = repmat(reshape(Y,1,M*K),Nindex,1))
445448
weights_tiled = np.tile(
446-
weights.reshape(1, self._num_conditions * self._n_components, order="F"), (self._signal_len, 1)
449+
weights.reshape(1, self.n_signals * self.n_components, order="F"), (self.signal_length, 1)
447450
)
448451

449452
# Compute floor indices (MATLAB: II = floor(ii); II1 = min(II+1,N+1); II2 = min(II1+1,N+1))
450453
floor_indices = np.floor(indices).astype(int)
451-
floor_indices_1 = np.minimum(floor_indices + 1, self._signal_len)
452-
floor_indices_2 = np.minimum(floor_indices_1 + 1, self._signal_len)
454+
floor_indices_1 = np.minimum(floor_indices + 1, self.signal_length)
455+
floor_indices_2 = np.minimum(floor_indices_1 + 1, self.signal_length)
453456

454457
# Compute fractional part (MATLAB: iI = ii - II)
455458
fractional_indices = indices - floor_indices
456459

457460
# Expand row indices (MATLAB: repm = repmat(1:K, Nindex, M))
458-
repm = np.tile(np.arange(self._n_components), (self._signal_len, self._num_conditions))
461+
repm = np.tile(np.arange(self.n_components), (self.signal_length, self.n_signals))
459462

460463
# Compute transformations (MATLAB: kro = kron(R(index,:), ones(1, K)))
461-
kron = np.kron(residuals, np.ones((1, self._n_components)))
464+
kron = np.kron(residuals, np.ones((1, self.n_components)))
462465

463466
# (MATLAB: kroiI = kro .* (iI); iIYY = (iI-1) .* YY)
464467
fractional_kron = kron * fractional_indices
@@ -467,16 +470,16 @@ def apply_transformation_matrix(self, stretch=None, weights=None, residuals=None
467470
# Construct sparse matrices (MATLAB: sparse(II1_,repm,kro.*-iIYY,(N+1),K))
468471
x2 = coo_matrix(
469472
((-kron * fractional_weights).flatten(), (floor_indices_1.flatten() - 1, repm.flatten())),
470-
shape=(self._signal_len + 1, self._n_components),
473+
shape=(self.signal_length + 1, self.n_components),
471474
).tocsc()
472475
x3 = coo_matrix(
473476
((fractional_kron * weights_tiled).flatten(), (floor_indices_2.flatten() - 1, repm.flatten())),
474-
shape=(self._signal_len + 1, self._n_components),
477+
shape=(self.signal_length + 1, self.n_components),
475478
).tocsc()
476479

477480
# Combine the last row into previous, then remove the last row
478-
x2[self._signal_len - 1, :] += x2[self._signal_len, :]
479-
x3[self._signal_len - 1, :] += x3[self._signal_len, :]
481+
x2[self.signal_length - 1, :] += x2[self.signal_length, :]
482+
x3[self.signal_length - 1, :] += x3[self.signal_length, :]
480483
x2 = x2[:-1, :]
481484
x3 = x3[:-1, :]
482485

@@ -543,10 +546,10 @@ def update_components(self):
543546
stretched_components, _, _ = self.apply_interpolation_matrix() # Skip the other two outputs (derivatives)
544547
# Compute RA and RR
545548
intermediate_reshaped = stretched_components.flatten(order="F").reshape(
546-
(self._signal_len * self._num_conditions, self._n_components), order="F"
549+
(self.signal_length * self.n_signals, self.n_components), order="F"
547550
)
548551
reshaped_stretched_components = intermediate_reshaped.sum(axis=1).reshape(
549-
(self._signal_len, self._num_conditions), order="F"
552+
(self.signal_length, self.n_signals), order="F"
550553
)
551554
component_residuals = reshaped_stretched_components - self.source_matrix
552555
# Compute gradient `GraX`
@@ -603,11 +606,11 @@ def update_weights(self):
603606
Updates weights using matrix operations, solving a quadratic program via to do so.
604607
"""
605608

606-
for m in range(self._num_conditions):
607-
t = np.zeros((self._signal_len, self._n_components))
609+
for m in range(self.n_signals):
610+
t = np.zeros((self.signal_length, self.n_components))
608611

609612
# Populate T using apply_interpolation
610-
for k in range(self._n_components):
613+
for k in range(self.n_components):
611614
t[:, k] = self.apply_interpolation(
612615
self.stretch[k, m], self.components[:, k], return_derivatives=True
613616
)[0].squeeze()
@@ -635,21 +638,19 @@ def regularize_function(self, stretch=None):
635638

636639
# Compute residual
637640
intermediate_diff = stretched_components.flatten(order="F").reshape(
638-
(self._signal_len * self._num_conditions, self._n_components), order="F"
639-
)
640-
stretch_difference = intermediate_diff.sum(axis=1).reshape(
641-
(self._signal_len, self._num_conditions), order="F"
641+
(self.signal_length * self.n_signals, self.n_components), order="F"
642642
)
643+
stretch_difference = intermediate_diff.sum(axis=1).reshape((self.signal_length, self.n_signals), order="F")
643644
stretch_difference = stretch_difference - self.source_matrix
644645

645646
# Compute objective function
646647
reg_func = self.get_objective_function(stretch_difference, stretch)
647648

648649
# Compute gradient
649650
tiled_derivative = np.sum(
650-
d_stretch_components * np.tile(stretch_difference, (1, self._n_components)), axis=0
651+
d_stretch_components * np.tile(stretch_difference, (1, self.n_components)), axis=0
651652
)
652-
der_reshaped = np.asarray(tiled_derivative).reshape((self._num_conditions, self._n_components), order="F")
653+
der_reshaped = np.asarray(tiled_derivative).reshape((self.n_signals, self.n_components), order="F")
653654
func_grad = (
654655
der_reshaped.T + self.rho * stretch @ self.spline_smooth_operator.T @ self.spline_smooth_operator
655656
)

0 commit comments

Comments
 (0)