Skip to content

Commit 1177b4f

Browse files
committed
Update weights and EMD loss
1 parent 20df435 commit 1177b4f

File tree

3 files changed

+14
-4
lines changed

3 files changed

+14
-4
lines changed

evaluate_mobilenet.py

Lines changed: 1 addition & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -18,10 +18,9 @@
1818
model = Model(base_model.input, x)
1919
model.load_weights('weights/mobilenet_weights.h5')
2020

21-
img_path = 'images/img.png'
21+
img_path = 'images/art1.jpg'
2222
img = load_img(img_path)
2323
x = img_to_array(img)
24-
2524
x = np.expand_dims(x, axis=0)
2625

2726
x = preprocess_input(x)

train_mobilenet.py

Lines changed: 13 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,10 @@
1+
import os
2+
13
from keras.models import Model
24
from keras.layers import Dense, Dropout
35
from keras.applications.mobilenet import MobileNet
46
from keras.callbacks import ModelCheckpoint, TensorBoard
7+
from keras.optimizers import Adam
58
from keras import backend as K
69

710
from data_loader import train_generator, val_generator
@@ -47,7 +50,10 @@ def on_epoch_end(self, epoch, logs=None):
4750
self.writer.flush()
4851

4952
def earth_mover_loss(y_true, y_pred):
50-
return K.sqrt(K.mean(K.square(K.abs(K.cumsum(y_true, axis=-1) - K.cumsum(y_pred, axis=-1)))))
53+
cdf_ytrue = K.cumsum(y_true, axis=-1)
54+
cdf_ypred = K.cumsum(y_pred, axis=-1)
55+
samplewise_emd = K.sqrt(K.mean(K.square(K.abs(cdf_ytrue - cdf_ypred)), axis=-1))
56+
return K.mean(samplewise_emd)
5157

5258
image_size = 224
5359

@@ -60,7 +66,12 @@ def earth_mover_loss(y_true, y_pred):
6066

6167
model = Model(base_model.input, x)
6268
model.summary()
63-
model.compile('adam', loss=earth_mover_loss)
69+
optimizer = Adam(lr=1e-4)
70+
model.compile(optimizer, loss=earth_mover_loss)
71+
72+
# load weights from trained model if it exists
73+
if os.path.exists('weights/mobilenet_weights.h5'):
74+
model.load_weights('weights/mobilenet_weights.h5')
6475

6576
checkpoint = ModelCheckpoint('weights/mobilenet_weights.h5', monitor='val_loss', verbose=1, save_weights_only=True, save_best_only=True,
6677
mode='min')

weights/mobilenet_weights.h5

0 Bytes
Binary file not shown.

0 commit comments

Comments
 (0)