@@ -45,3 +45,75 @@ def test_final_objective_below_threshold(inputs):
4545 # Basic sanity check and the actual assertion
4646 assert np .isfinite (model .objective_function )
4747 assert model .objective_function < 5e6
48+
49+
50+ @pytest .mark .parametrize (
51+ "inputs, expected" ,
52+ # inputs tuple: (components, residuals, stretch, rho, eta, spline smoothness operator)
53+ [
54+ # Case 0: No smoothness or sparsity penalty, reduces to standard NMF objective
55+ # residual Frobenius norm^2 = 3^2 + 4^2 = 25 -> 0.5 * 25 = 12.5
56+ (
57+ (
58+ np .array ([[0.0 , 0.0 ], [3.0 , 4.0 ]]),
59+ np .array ([[0.0 , 0.0 ], [3.0 , 4.0 ]]),
60+ np .ones ((2 , 2 )),
61+ 0.0 ,
62+ 0.0 ,
63+ np .zeros ((2 , 2 )),
64+ ),
65+ 12.5 ,
66+ ),
67+ # Case 1: rho = 0, sparsity penalty only
68+ # sqrt components sum = 1 + 2 + 3 + 4 = 10 -> eta * 10 = 5
69+ # residual term remains 12.5 -> total = 17.5
70+ (
71+ (
72+ np .array ([[1.0 , 4.0 ], [9.0 , 16.0 ]]),
73+ np .array ([[3.0 , 4.0 ], [0.0 , 0.0 ]]),
74+ np .ones ((2 , 2 )),
75+ 0.0 ,
76+ 0.5 ,
77+ np .zeros ((2 , 2 )),
78+ ),
79+ 17.5 ,
80+ ),
81+ # Case 2: eta = 0, smoothness penalty only
82+ # residual = 12.5, smoothing = 0.5 * 1 * 1 = 0.5 -> total = 13.0
83+ (
84+ (
85+ np .array ([[1.0 , 2.0 ], [3.0 , 4.0 ]]),
86+ np .array ([[3.0 , 4.0 ], [0.0 , 0.0 ]]),
87+ np .array ([[1.0 , 2.0 ]]),
88+ 1.0 ,
89+ 0.0 ,
90+ np .array ([[1.0 , - 1.0 ]]),
91+ ),
92+ 13.0 ,
93+ ),
94+ # Case 3: penalty for smoothness and sparsity
95+ # residual = 2.5, sparsity = 1.5, smoothing = 9 -> total = 13.0
96+ (
97+ (
98+ np .array ([[1.0 , 4.0 ]]),
99+ np .array ([[1.0 , 2.0 ]]),
100+ np .array ([[1.0 , 4.0 ]]),
101+ 2.0 ,
102+ 0.5 ,
103+ np .array ([[3.0 , 0.0 ]]),
104+ ),
105+ 13.0 ,
106+ ),
107+ ],
108+ )
109+ def test_compute_objective_function (inputs , expected ):
110+ components , residuals , stretch , rho , eta , operator = inputs
111+ result = SNMFOptimizer ._compute_objective_function (
112+ components = components ,
113+ residuals = residuals ,
114+ stretch = stretch ,
115+ rho = rho ,
116+ eta = eta ,
117+ spline_smooth_operator = operator ,
118+ )
119+ assert np .isclose (result , expected )
0 commit comments