Commit
Merge pull request #6 from ahundt/learn_rate
automatically reduce learning rate
titu1994 authored Feb 12, 2017
2 parents 17fbeb1 + c93e321 commit 028bbfb
Showing 1 changed file with 27 additions and 10 deletions.
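In the diff below, lines prefixed with + were added and lines prefixed with - were removed. The change wires three Keras callbacks into training: ReduceLROnPlateau multiplies the learning rate by sqrt(0.1), about 0.32, whenever val_loss has not improved for 10 epochs; EarlyStopping halts training once val_acc has stalled for 20 epochs; and ModelCheckpoint keeps only the best weights. As a minimal sketch (not part of the commit) of the schedule those ReduceLROnPlateau settings imply, assuming a 1e-3 starting rate (Keras Adam's default; the rate the script actually passes to its optimizer lies outside the hunks shown here):

# Illustration only: the geometric learning-rate schedule implied by
# ReduceLROnPlateau(factor=np.sqrt(0.1), min_lr=0.5e-6).
import numpy as np

lr = 1e-3                # assumed initial learning rate (Keras Adam default)
factor = np.sqrt(0.1)    # ~0.32x reduction per plateau
min_lr = 0.5e-6          # floor used in the commit

reductions = 0
while lr > min_lr:
    print("after %d reductions: lr = %.2e" % (reductions, lr))
    lr = max(lr * factor, min_lr)
    reductions += 1
print("floor reached after %d reductions: lr = %.2e" % (reductions, lr))

With these settings the rate can be cut at most seven times before it is clamped at 5e-7, so the schedule spans roughly three orders of magnitude.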
37 changes: 27 additions & 10 deletions cifar10.py
@@ -1,5 +1,7 @@
from __future__ import print_function

+import os.path
+
import densenet
import numpy as np
import sklearn.metrics as metrics
@@ -8,12 +10,16 @@
from keras.utils import np_utils
from keras.preprocessing.image import ImageDataGenerator
from keras.optimizers import Adam
-from keras.callbacks import ModelCheckpoint
+from keras.callbacks import CSVLogger, ModelCheckpoint, ReduceLROnPlateau, EarlyStopping
from keras import backend as K

+
+import tensorflow as tf
+import datetime
+
batch_size = 64
nb_classes = 10
-nb_epoch = 200
+nb_epoch = 300

img_rows, img_cols = 32, 32
img_channels = 3
@@ -53,14 +59,25 @@
generator.fit(trainX, seed=0)

# Load model
-model.load_weights("weights/DenseNet-40-12-CIFAR10.h5")
-print("Model loaded.")
-
-# model.fit_generator(generator.flow(trainX, Y_train, batch_size=batch_size), samples_per_epoch=len(trainX), nb_epoch=nb_epoch,
-#                     callbacks=[ModelCheckpoint("weights/DenseNet-40-12-CIFAR10.h5", monitor="val_acc", save_best_only=True,
-#                     save_weights_only=True)],
-#                     validation_data=(testX, Y_test),
-#                     nb_val_samples=testX.shape[0], verbose=2)
+weights_file="weights/DenseNet-40-12CIFAR10-tf.h5"
+if os.path.exists(weights_file):
+    model.load_weights(weights_file)
+    print("Model loaded.")
+
+out_dir="weights/"
+
+lr_reducer = ReduceLROnPlateau(monitor='val_loss', factor=np.sqrt(0.1),
+                               cooldown=0, patience=10, min_lr=0.5e-6)
+early_stopper = EarlyStopping(monitor='val_acc', min_delta=0.0001, patience=20)
+model_checkpoint= ModelCheckpoint(weights_file, monitor="val_acc", save_best_only=True,
+                                  save_weights_only=True, mode='auto')
+
+callbacks=[lr_reducer, early_stopper, model_checkpoint]
+
+model.fit_generator(generator.flow(trainX, Y_train, batch_size=batch_size), samples_per_epoch=len(trainX), nb_epoch=nb_epoch,
+                    callbacks=callbacks,
+                    validation_data=(testX, Y_test),
+                    nb_val_samples=testX.shape[0], verbose=2)

yPreds = model.predict(testX)
yPred = np.argmax(yPreds, axis=1)
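The training call uses the Keras 1.x generator API that was current in early 2017 (fit_generator with samples_per_epoch and nb_val_samples). Purely as a hedged sketch, not anything from this repository, the same callback wiring in present-day tf.keras looks roughly as follows; the toy model, the random data, and the best_weights.h5 filename are placeholders:

# Sketch: the commit's callback setup expressed against modern tf.keras.
# The model and data here are throwaway stand-ins, not the DenseNet from this repo.
import numpy as np
import tensorflow as tf

x_train = np.random.rand(256, 32, 32, 3).astype("float32")
y_train = tf.keras.utils.to_categorical(np.random.randint(0, 10, 256), 10)
x_val = np.random.rand(64, 32, 32, 3).astype("float32")
y_val = tf.keras.utils.to_categorical(np.random.randint(0, 10, 64), 10)

model = tf.keras.Sequential([
    tf.keras.layers.Flatten(input_shape=(32, 32, 3)),
    tf.keras.layers.Dense(64, activation="relu"),
    tf.keras.layers.Dense(10, activation="softmax"),
])
model.compile(optimizer="adam", loss="categorical_crossentropy", metrics=["accuracy"])

callbacks = [
    # Same hyper-parameters as the commit: shrink the LR on a val_loss plateau ...
    tf.keras.callbacks.ReduceLROnPlateau(monitor="val_loss", factor=np.sqrt(0.1),
                                         cooldown=0, patience=10, min_lr=0.5e-6),
    # ... stop when validation accuracy stalls ...
    tf.keras.callbacks.EarlyStopping(monitor="val_accuracy", min_delta=1e-4, patience=20),
    # ... and keep only the best weights on disk.
    tf.keras.callbacks.ModelCheckpoint("best_weights.h5", monitor="val_accuracy",
                                       save_best_only=True, save_weights_only=True),
]

# fit() replaces fit_generator(); epochs and validation_data map directly.
model.fit(x_train, y_train, batch_size=64, epochs=3,
          validation_data=(x_val, y_val), callbacks=callbacks, verbose=2)

The callback arguments carry over unchanged; the visible differences are the monitored metric name (val_accuracy rather than val_acc) and fit replacing fit_generator.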
