Skip to content

Commit 4ff1eb0

Browse files
author
John Halloran
committed
style: remove non-class function from class
1 parent fdbacb3 commit 4ff1eb0

File tree

1 file changed

+59
-57
lines changed

1 file changed

+59
-57
lines changed

src/diffpy/snmf/snmf_class.py

Lines changed: 59 additions & 57 deletions
Original file line numberDiff line numberDiff line change
@@ -113,13 +113,15 @@ def __init__(
113113
The number of components to extract from source_matrix. Must be provided when and only when
114114
init_weights is not provided.
115115
random_state : int Optional Default = None
116-
The seed for the initial guesses at the matrices (A, X, and Y) created by
116+
The seed for the initial guesses at the matrices (stretch, components, and weights) created by
117117
the decomposition.
118118
"""
119119

120120
self.source_matrix = source_matrix
121121
self.rho = rho
122122
self.eta = eta
123+
self.tol = tol
124+
self.max_iter = max_iter
123125
# Capture matrix dimensions
124126
self.signal_length, self.n_signals = source_matrix.shape
125127
self.num_updates = 0
@@ -164,18 +166,19 @@ def __init__(
164166
self._spline_smooth_operator = 0.25 * diags(
165167
[1, -2, 1], offsets=[0, 1, 2], shape=(self.n_signals - 2, self.n_signals)
166168
)
167-
self._spline_smooth_penalty = self._spline_smooth_operator.T @ self._spline_smooth_operator
168169

169170
# Set up residual matrix, objective function, and history
170171
self.residuals = self.get_residual_matrix()
171-
self._objective_history = []
172-
self.update_objective()
172+
self.objective_function = self.get_objective_function()
173+
self.best_objective = self.objective_function
174+
self.best_matrices = [self.components.copy(), self.weights.copy(), self.stretch.copy()]
173175
self.objective_difference = None
176+
self._objective_history = [self.objective_function]
174177

175178
# Set up tracking variables for update_components()
176179
self._prev_components = None
177-
self.grad_components = np.zeros_like(self.components) # Gradient of X (zeros for now)
178-
self._prev_grad_components = np.zeros_like(self.components) # Previous gradient of X (zeros for now)
180+
self.grad_components = np.zeros_like(self.components)
181+
self._prev_grad_components = np.zeros_like(self.components)
179182

180183
regularization_term = 0.5 * rho * np.linalg.norm(self._spline_smooth_operator @ self.stretch.T, "fro") ** 2
181184
sparsity_term = eta * np.sum(np.sqrt(self.components)) # Square root penalty
@@ -265,57 +268,6 @@ def optimize_loop(self):
265268

266269
self.objective_difference = self._objective_history[-2] - self._objective_history[-1]
267270

268-
def apply_interpolation(self, a, x, return_derivatives=False):
269-
"""
270-
Applies an interpolation-based transformation to `x` based on scaling `a`.
271-
Also can compute first (`d_intr_x`) and second (`dd_intr_x`) derivatives.
272-
"""
273-
x_len = len(x)
274-
275-
# Ensure `a` is an array and reshape for broadcasting
276-
a = np.atleast_1d(np.asarray(a)) # Ensures a is at least 1D
277-
278-
# Compute fractional indices, broadcasting over `a`
279-
fractional_indices = np.arange(x_len)[:, None] / a # Shape (N, M)
280-
281-
integer_indices = np.floor(fractional_indices).astype(int) # Integer part (still (N, M))
282-
valid_mask = integer_indices < (x_len - 1) # Ensure indices are within bounds
283-
284-
# Apply valid_mask to keep correct indices
285-
idx_int = np.where(
286-
valid_mask, integer_indices, x_len - 2
287-
) # Prevent out-of-bounds indexing (previously "I")
288-
idx_frac = np.where(valid_mask, fractional_indices, integer_indices) # Keep aligned (previously "i")
289-
290-
# Ensure x is a 1D array
291-
x = np.asarray(x).ravel()
292-
293-
# Compute interpolated_x (linear interpolation)
294-
interpolated_x = x[idx_int] * (1 - idx_frac + idx_int) + x[np.minimum(idx_int + 1, x_len - 1)] * (
295-
idx_frac - idx_int
296-
)
297-
298-
# Fill the tail with the last valid value
299-
intr_x_tail = np.full((x_len - len(idx_int), interpolated_x.shape[1]), interpolated_x[-1, :])
300-
interpolated_x = np.vstack([interpolated_x, intr_x_tail])
301-
302-
if return_derivatives:
303-
# Compute first derivative (d_intr_x)
304-
di = -idx_frac / a
305-
d_intr_x = x[idx_int] * (-di) + x[np.minimum(idx_int + 1, x_len - 1)] * di
306-
d_intr_x = np.vstack([d_intr_x, np.zeros((x_len - len(idx_int), d_intr_x.shape[1]))])
307-
308-
# Compute second derivative (dd_intr_x)
309-
ddi = -di / a + idx_frac * a**-2
310-
dd_intr_x = x[idx_int] * (-ddi) + x[np.minimum(idx_int + 1, x_len - 1)] * ddi
311-
dd_intr_x = np.vstack([dd_intr_x, np.zeros((x_len - len(idx_int), dd_intr_x.shape[1]))])
312-
else:
313-
# Make placeholders
314-
d_intr_x = np.empty(interpolated_x.shape)
315-
dd_intr_x = np.empty(interpolated_x.shape)
316-
317-
return interpolated_x, d_intr_x, dd_intr_x
318-
319271
def get_residual_matrix(self, components=None, weights=None, stretch=None):
320272
# Initialize residual matrix as negative of source_matrix
321273
# In MATLAB this is getR
@@ -722,3 +674,53 @@ def cubic_largest_real_root(p, q):
722674

723675
# Choose correct root depending on sign of delta
724676
return np.where(delta >= 0, root1, root2)
677+
678+
679+
def apply_interpolation(self, a, x, return_derivatives=False):
680+
"""
681+
Applies an interpolation-based transformation to `x` based on scaling `a`.
682+
Also can compute first (`d_intr_x`) and second (`dd_intr_x`) derivatives.
683+
"""
684+
x_len = len(x)
685+
686+
# Ensure `a` is an array and reshape for broadcasting
687+
a = np.atleast_1d(np.asarray(a)) # Ensures a is at least 1D
688+
689+
# Compute fractional indices, broadcasting over `a`
690+
fractional_indices = np.arange(x_len)[:, None] / a # Shape (N, M)
691+
692+
integer_indices = np.floor(fractional_indices).astype(int) # Integer part (still (N, M))
693+
valid_mask = integer_indices < (x_len - 1) # Ensure indices are within bounds
694+
695+
# Apply valid_mask to keep correct indices
696+
idx_int = np.where(valid_mask, integer_indices, x_len - 2) # Prevent out-of-bounds indexing (previously "I")
697+
idx_frac = np.where(valid_mask, fractional_indices, integer_indices) # Keep aligned (previously "i")
698+
699+
# Ensure x is a 1D array
700+
x = np.asarray(x).ravel()
701+
702+
# Compute interpolated_x (linear interpolation)
703+
interpolated_x = x[idx_int] * (1 - idx_frac + idx_int) + x[np.minimum(idx_int + 1, x_len - 1)] * (
704+
idx_frac - idx_int
705+
)
706+
707+
# Fill the tail with the last valid value
708+
intr_x_tail = np.full((x_len - len(idx_int), interpolated_x.shape[1]), interpolated_x[-1, :])
709+
interpolated_x = np.vstack([interpolated_x, intr_x_tail])
710+
711+
if return_derivatives:
712+
# Compute first derivative (d_intr_x)
713+
di = -idx_frac / a
714+
d_intr_x = x[idx_int] * (-di) + x[np.minimum(idx_int + 1, x_len - 1)] * di
715+
d_intr_x = np.vstack([d_intr_x, np.zeros((x_len - len(idx_int), d_intr_x.shape[1]))])
716+
717+
# Compute second derivative (dd_intr_x)
718+
ddi = -di / a + idx_frac * a**-2
719+
dd_intr_x = x[idx_int] * (-ddi) + x[np.minimum(idx_int + 1, x_len - 1)] * ddi
720+
dd_intr_x = np.vstack([dd_intr_x, np.zeros((x_len - len(idx_int), dd_intr_x.shape[1]))])
721+
else:
722+
# Make placeholders
723+
d_intr_x = np.empty(interpolated_x.shape)
724+
dd_intr_x = np.empty(interpolated_x.shape)
725+
726+
return interpolated_x, d_intr_x, dd_intr_x

0 commit comments

Comments
 (0)