1
+ import os
2
+
1
3
from keras .models import Model
2
4
from keras .layers import Dense , Dropout
3
5
from keras .applications .mobilenet import MobileNet
4
6
from keras .callbacks import ModelCheckpoint , TensorBoard
7
+ from keras .optimizers import Adam
5
8
from keras import backend as K
6
9
7
10
from data_loader import train_generator , val_generator
@@ -47,7 +50,10 @@ def on_epoch_end(self, epoch, logs=None):
47
50
self .writer .flush ()
48
51
49
52
def earth_mover_loss (y_true , y_pred ):
50
- return K .sqrt (K .mean (K .square (K .abs (K .cumsum (y_true , axis = - 1 ) - K .cumsum (y_pred , axis = - 1 )))))
53
+ cdf_ytrue = K .cumsum (y_true , axis = - 1 )
54
+ cdf_ypred = K .cumsum (y_pred , axis = - 1 )
55
+ samplewise_emd = K .sqrt (K .mean (K .square (K .abs (cdf_ytrue - cdf_ypred )), axis = - 1 ))
56
+ return K .mean (samplewise_emd )
51
57
52
58
image_size = 224
53
59
@@ -60,7 +66,12 @@ def earth_mover_loss(y_true, y_pred):
60
66
61
67
model = Model (base_model .input , x )
62
68
model .summary ()
63
- model .compile ('adam' , loss = earth_mover_loss )
69
+ optimizer = Adam (lr = 1e-4 )
70
+ model .compile (optimizer , loss = earth_mover_loss )
71
+
72
+ # load weights from trained model if it exists
73
+ if os .path .exists ('weights/mobilenet_weights.h5' ):
74
+ model .load_weights ('weights/mobilenet_weights.h5' )
64
75
65
76
checkpoint = ModelCheckpoint ('weights/mobilenet_weights.h5' , monitor = 'val_loss' , verbose = 1 , save_weights_only = True , save_best_only = True ,
66
77
mode = 'min' )
0 commit comments