@@ -40,9 +40,8 @@ def __init__(self, args):
4040 self ._eval_every_x_epochs = args .get ("eval_every_x_epochs" )
4141
4242 self ._use_mimic_score = args .get ("mimic_score" )
43- self ._use_less_forget = args .get ("less_forget" )
44- self ._lambda_schedule = args .get ("lambda_schedule" , True )
45- self ._use_ranking = args .get ("ranking_loss" )
43+ self ._less_forget = args .get ("less_forget" )
44+ self ._ranking_loss = args .get ("ranking_loss" )
4645
4746 self ._network = network .BasicNet (
4847 args ["convnet" ],
@@ -62,10 +61,6 @@ def __init__(self, args):
6261
6362 self ._finetuning_config = args .get ("finetuning_config" )
6463
65- self ._lambda = args .get ("base_lambda" , 5 )
66- self ._nb_negatives = args .get ("nb_negatives" , 2 )
67- self ._margin = args .get ("ranking_margin" , 0.2 )
68-
6964 self ._weight_generation = args .get ("weight_generation" )
7065
7166 self ._herding_indexes = []
@@ -79,6 +74,10 @@ def __init__(self, args):
7974 self ._args = args
8075 self ._args ["_logs" ] = {}
8176
77+ self ._during_finetune = False
78+ self ._clip_classifier = None
79+ self ._align_weights_after_epoch = False
80+
8281 def _after_task (self , inc_dataset ):
8382 if "scale" not in self ._args ["_logs" ]:
8483 self ._args ["_logs" ]["scale" ] = []
@@ -205,11 +204,11 @@ def _compute_loss(self, inputs, outputs, targets, onehot_targets, memory_flags):
205204 old_outputs = self ._old_model (inputs )
206205 old_features = old_outputs ["raw_features" ]
207206
208- if self ._use_less_forget :
209- if self ._lambda_schedule :
210- scheduled_lambda = self ._lambda * math .sqrt (self ._n_classes / self ._task_size )
207+ if self ._less_forget :
208+ if self ._less_forget [ "scheduled_factor" ] :
209+ scheduled_lambda = self ._less_forget [ "lambda" ] * math .sqrt (self ._n_classes / self ._task_size )
211210 else :
212- scheduled_lambda = 1.
211+ scheduled_lambda = self . _less_forget [ "lambda" ]
213212
214213 lessforget_loss = scheduled_lambda * losses .embeddings_similarity (
215214 old_features , features
@@ -225,14 +224,14 @@ def _compute_loss(self, inputs, outputs, targets, onehot_targets, memory_flags):
225224 loss += mimic_loss
226225 self ._metrics ["mimic" ] += mimic_loss .item ()
227226
228- if self ._use_ranking :
229- ranking_loss = losses .ucir_ranking (
227+ if self ._ranking_loss :
228+ ranking_loss = self . _ranking_loss [ "factor" ] * losses .ucir_ranking (
230229 logits ,
231230 targets ,
232231 self ._n_classes ,
233232 self ._task_size ,
234- nb_negatives = max (self ._nb_negatives , self ._task_size ),
235- margin = self ._margin
233+ nb_negatives = min (self ._ranking_loss [ "nb_negatives" ] , self ._task_size ),
234+ margin = self ._ranking_loss [ "margin" ]
236235 )
237236 loss += ranking_loss
238237 self ._metrics ["rank" ] += ranking_loss .item ()
0 commit comments