Skip to content

Commit 6adae6f

Browse files
author
John Halloran
committed
refactor: change get_objective_function into a static method and getter
1 parent 9d4f98b commit 6adae6f

File tree

2 files changed

+95
-0
lines changed

2 files changed

+95
-0
lines changed

news/declass-obj.rst

Lines changed: 23 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,23 @@
1+
**Added:**
2+
3+
* Implement tests for ``compute_objective_function()``
4+
5+
**Changed:**
6+
7+
* Refactor ``get_objective_function()`` into a static method and getter
8+
9+
**Deprecated:**
10+
11+
* <news item>
12+
13+
**Removed:**
14+
15+
* <news item>
16+
17+
**Fixed:**
18+
19+
* <news item>
20+
21+
**Security:**
22+
23+
* <news item>

tests/test_snmf_optimizer.py

Lines changed: 72 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -45,3 +45,75 @@ def test_final_objective_below_threshold(inputs):
4545
# Basic sanity check and the actual assertion
4646
assert np.isfinite(model.objective_function)
4747
assert model.objective_function < 5e6
48+
49+
50+
@pytest.mark.parametrize(
51+
"inputs, expected",
52+
# inputs tuple: (components, residuals, stretch, rho, eta, spline smoothness operator)
53+
[
54+
# Case 0: No smoothness or sparsity penalty, reduces to standard NMF objective
55+
# residual Frobenius norm^2 = 3^2 + 4^2 = 25 -> 0.5 * 25 = 12.5
56+
(
57+
(
58+
np.array([[0.0, 0.0], [3.0, 4.0]]),
59+
np.array([[0.0, 0.0], [3.0, 4.0]]),
60+
np.ones((2, 2)),
61+
0.0,
62+
0.0,
63+
np.zeros((2, 2)),
64+
),
65+
12.5,
66+
),
67+
# Case 1: rho = 0, sparsity penalty only
68+
# sqrt components sum = 1 + 2 + 3 + 4 = 10 -> eta * 10 = 5
69+
# residual term remains 12.5 -> total = 17.5
70+
(
71+
(
72+
np.array([[1.0, 4.0], [9.0, 16.0]]),
73+
np.array([[3.0, 4.0], [0.0, 0.0]]),
74+
np.ones((2, 2)),
75+
0.0,
76+
0.5,
77+
np.zeros((2, 2)),
78+
),
79+
17.5,
80+
),
81+
# Case 2: eta = 0, smoothness penalty only
82+
# residual = 12.5, smoothing = 0.5 * 1 * 1 = 0.5 -> total = 13.0
83+
(
84+
(
85+
np.array([[1.0, 2.0], [3.0, 4.0]]),
86+
np.array([[3.0, 4.0], [0.0, 0.0]]),
87+
np.array([[1.0, 2.0]]),
88+
1.0,
89+
0.0,
90+
np.array([[1.0, -1.0]]),
91+
),
92+
13.0,
93+
),
94+
# Case 3: penalty for smoothness and sparsity
95+
# residual = 2.5, sparsity = 1.5, smoothing = 9 -> total = 13.0
96+
(
97+
(
98+
np.array([[1.0, 4.0]]),
99+
np.array([[1.0, 2.0]]),
100+
np.array([[1.0, 4.0]]),
101+
2.0,
102+
0.5,
103+
np.array([[3.0, 0.0]]),
104+
),
105+
13.0,
106+
),
107+
],
108+
)
109+
def test_compute_objective_function(inputs, expected):
110+
components, residuals, stretch, rho, eta, operator = inputs
111+
result = SNMFOptimizer._compute_objective_function(
112+
components=components,
113+
residuals=residuals,
114+
stretch=stretch,
115+
rho=rho,
116+
eta=eta,
117+
spline_smooth_operator=operator,
118+
)
119+
assert np.isclose(result, expected)

0 commit comments

Comments
 (0)