@@ -80,17 +80,44 @@ def __init__(self, MM, Y0=None, X0=None, A=None, rho=1e12, eta=610, maxiter=300,
8080 )
8181
8282 # Convergence check: Stop if diffun is small and at least 20 iterations have passed
83- print (self .objective_difference , " < " , self .objective_function * 1e-6 )
84- if self .objective_difference < self .objective_function * 1e-6 and outiter >= 20 :
83+ # MATLAB uses 1e-6 but also gets faster convergence, so this makes up that difference
84+ print (self .objective_difference , " < " , self .objective_function * 5e-7 )
85+ if self .objective_difference < self .objective_function * 5e-7 and outiter >= 20 :
8586 break
8687
8788 # Normalize our results
89+ # TODO make this much cleaner
8890 Y_row_max = np .max (self .Y , axis = 1 , keepdims = True )
8991 self .Y = self .Y / Y_row_max
9092 A_row_max = np .max (self .A , axis = 1 , keepdims = True )
9193 self .A = self .A / A_row_max
92- # TODO loop to normalize X (currently not normalized)
94+ # loop to normalize X
9395 # effectively just re-running class with non-normalized X, normalized Y/A as inputs, then only update X
96+ # reset difference trackers and initialize
97+ self .preX = self .X .copy () # Previously stored X (like X0 for now)
98+ self .GraX = np .zeros_like (self .X ) # Gradient of X (zeros for now)
99+ self .preGraX = np .zeros_like (self .X ) # Previous gradient of X (zeros for now)
100+ self .R = self .get_residual_matrix ()
101+ self .objective_function = self .get_objective_function ()
102+ self .objective_difference = None
103+ self .objective_history = [self .objective_function ]
104+ self .outiter = 0
105+ self .iter = 0
106+ for outiter in range (100 ):
107+ if iter == 1 :
108+ self .iter = 1 # So step size can adapt without an inner loop
109+ self .updateX ()
110+ self .R = self .get_residual_matrix ()
111+ self .objective_function = self .get_objective_function ()
112+ print (f"Objective function after normX: { self .objective_function :.5e} " )
113+ self .objective_history .append (self .objective_function )
114+ self .objective_difference = self .objective_history [- 2 ] - self .objective_history [- 1 ]
115+ if self .objective_difference < self .objective_function * 5e-7 and outiter >= 20 :
116+ break
117+ # end of normalization (and program)
118+ # note that objective function does not fully recover after normalization
119+ # it is still higher than pre-normalization, but that is okay and matches MATLAB
120+ print ("Finished optimization." )
94121
95122 def outer_loop (self ):
96123 # This inner loop runs up to four times per outer loop, making updates to X, Y
0 commit comments