Skip to content

Commit bbd8ac0

Browse files
FIX: evaluate was not working with tracing options + few enh (#280)
1 parent c0f6824 commit bbd8ac0

File tree

1 file changed

+9
-5
lines changed

1 file changed

+9
-5
lines changed

choice_learn/models/base_model.py

Lines changed: 9 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -285,12 +285,16 @@ def fit(
285285
batch_size = self.batch_size
286286

287287
losses_history = {"train_loss": []}
288-
t_range = tqdm.trange(epochs, position=0)
288+
if verbose >= 0 and verbose < 2:
289+
t_range = tqdm.trange(epochs, position=0)
290+
else:
291+
t_range = range(epochs)
289292

290293
self.callbacks.on_train_begin()
291-
292294
# Iterate of epochs
293295
for epoch_nb in t_range:
296+
if verbose >= 2:
297+
print(f"Start Epoch {epoch_nb}")
294298
self.callbacks.on_epoch_begin(epoch_nb)
295299
t_start = time.time()
296300
train_logs = {"train_loss": []}
@@ -340,7 +344,7 @@ def fit(
340344

341345
if verbose > 0:
342346
inner_range.set_description(
343-
f"Epoch Negative-LogLikeliHood: {np.sum(epoch_losses):.4f}"
347+
f"Epoch Negative-LogLikeliHood: {np.mean(epoch_losses):.4f}"
344348
)
345349

346350
# In this case we do not need to batch the sample_weights
@@ -376,7 +380,7 @@ def fit(
376380

377381
if verbose > 0:
378382
inner_range.set_description(
379-
f"Epoch Negative-LogLikeliHood: {np.sum(epoch_losses):.4f}"
383+
f"Epoch Negative-LogLikeliHood: {np.mean(epoch_losses):.4f}"
380384
)
381385

382386
# Take into account the fact that the last batch may have a
@@ -444,7 +448,7 @@ def fit(
444448
self.callbacks.on_train_end(logs=temps_logs)
445449
return losses_history
446450

447-
@tf.function(reduce_retracing=True)
451+
@tf.function()
448452
def batch_predict(
449453
self,
450454
shared_features_by_choice,

0 commit comments

Comments
 (0)