Skip to content

Commit 9d4f98b

Browse files
author
John Halloran
committed
refactor: compute objective function in a static method and retrieve via getter
1 parent d83619e commit 9d4f98b

File tree

1 file changed

+77
-10
lines changed

1 file changed

+77
-10
lines changed

src/diffpy/snmf/snmf_class.py

Lines changed: 77 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -359,16 +359,30 @@ def get_residual_matrix(self, components=None, weights=None, stretch=None):
359359
return residuals
360360

361361
def get_objective_function(self, residuals=None, stretch=None):
362-
if residuals is None:
363-
residuals = self.residuals
364-
if stretch is None:
365-
stretch = self.stretch_
366-
residual_term = 0.5 * np.linalg.norm(residuals, "fro") ** 2
367-
regularization_term = 0.5 * self.rho * np.linalg.norm(self._spline_smooth_operator @ stretch.T, "fro") ** 2
368-
sparsity_term = self.eta * np.sum(np.sqrt(self.components_)) # Square root penalty
369-
# Final objective function value
370-
function = residual_term + regularization_term + sparsity_term
371-
return function
362+
"""
363+
Return the objective value, passing stored attributes or overrides
364+
to _compute_objective_function().
365+
366+
Parameters
367+
----------
368+
residuals : ndarray, optional
369+
Residual matrix to use instead of self.residuals.
370+
stretch : ndarray, optional
371+
Stretch matrix to use instead of self.stretch_.
372+
373+
Returns
374+
-------
375+
float
376+
Current objective function value.
377+
"""
378+
return SNMFOptimizer._compute_objective_function(
379+
components=self.components_,
380+
residuals=self.residuals if residuals is None else residuals,
381+
stretch=self.stretch_ if stretch is None else stretch,
382+
rho=self.rho,
383+
eta=self.eta,
384+
spline_smooth_operator=self._spline_smooth_operator,
385+
)
372386

373387
def compute_stretched_components(self, components=None, weights=None, stretch=None):
374388
"""
@@ -702,6 +716,59 @@ def objective(stretch_vec):
702716
# Update stretch with the optimized values
703717
self.stretch_ = result.x.reshape(self.stretch_.shape)
704718

719+
@staticmethod
720+
def _compute_objective_function(components, residuals, stretch, rho, eta, spline_smooth_operator):
721+
r"""
722+
Computes the objective function used in stretched non-negative matrix factorization.
723+
724+
Parameters
725+
----------
726+
components : ndarray
727+
Non-negative matrix of component signals :math:`X`.
728+
residuals : ndarray
729+
Difference between reconstructed and observed data.
730+
stretch : ndarray
731+
Stretching factors :math:`A` applied to each component across samples.
732+
rho : float
733+
Regularization parameter enforcing smooth variation in :math:`A`.
734+
eta : float
735+
Sparsity-promoting regularization parameter applied to :math:`X`.
736+
spline_smooth_operator : ndarray
737+
Linear operator :math:`L` penalizing non-smooth changes in :math:`A`.
738+
739+
Returns
740+
-------
741+
float
742+
Value of the stretched-NMF objective function.
743+
744+
Notes
745+
-----
746+
The stretched-NMF objective function :math:`J` is
747+
748+
.. math::
749+
750+
J(X, Y, A) =
751+
\tfrac{1}{2} \lVert Z - Y\,S(A)X \rVert_F^2
752+
+ \tfrac{\rho}{2} \lVert L A \rVert_F^2
753+
+ \eta \sum_{i,j} \sqrt{X_{ij}} \,,
754+
755+
where :math:`Z` is the data matrix, :math:`Y` contains the non-negative
756+
weights, :math:`S(A)` denotes the spline-interpolated stretching operator,
757+
and :math:`\lVert \cdot \rVert_F` is the Frobenius norm.
758+
759+
Special cases
760+
-------------
761+
- :math:`\rho = 0` — no smoothness regularization on stretching factors.
762+
- :math:`\eta = 0` — no sparsity promotion on components.
763+
- :math:`\rho = \eta = 0` — reduces to the classical NMF least-squares
764+
objective :math:`\tfrac{1}{2} \lVert Z - YX \rVert_F^2`.
765+
766+
"""
767+
residual_term = 0.5 * np.linalg.norm(residuals, "fro") ** 2
768+
regularization_term = 0.5 * rho * np.linalg.norm(spline_smooth_operator @ stretch.T, "fro") ** 2
769+
sparsity_term = eta * np.sum(np.sqrt(components))
770+
return residual_term + regularization_term + sparsity_term
771+
705772

706773
def cubic_largest_real_root(p, q):
707774
"""

0 commit comments

Comments
 (0)