Commit ca881ac

refactored code

BKHMSI committed Oct 8, 2018
1 parent 1c94351 commit ca881ac
Showing 6 changed files with 154 additions and 100 deletions.
9 changes: 6 additions & 3 deletions config.yaml
@@ -1,18 +1,21 @@
 run-title: "SoftmaxMnist"
 
 paths:
-  save: "Models/1"
-  load: ""
+  save: "Models/13"
+  load: "Models/11/model.78-0.0392.h5"
 
 train:
   lr: 0.001
   optim: "Nadam"
 
   epochs: 100
-  batch-size: 250
+  batch-size: 400
 
   loss: "categorical-crossentropy"
   alpha: 0.2
+  beta: 0.1
+  scale: 64
+  reg_lambda: 0.01
 
   lr_reduce_factor: 0.5
   patience: 5
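
A note on the new keys: alpha is the cosine margin of the large-margin cosine loss added in loss.py below, scale is the multiplier applied to the cosine logits, and reg_lambda weights the L2 penalty on that loss's class-center weights; beta appears to feed the intra-enhanced triplet loss instead. A minimal NumPy sketch of how alpha and scale shape the logits (illustrative only, not repository code):

    import numpy as np

    def cos_margin_logits(cos_sim, labels, alpha=0.2, scale=64):
        # cos_sim: (batch, num_classes) cosine similarities in [-1, 1]
        # subtract the margin only at each sample's true class, then scale
        logits = cos_sim.copy()
        logits[np.arange(len(labels)), labels] -= alpha
        return scale * logits
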
27 changes: 23 additions & 4 deletions data.py
@@ -21,7 +21,7 @@ def load(self):

         self.X_train = self.preprocess(X_train)
         self.X_test = self.preprocess(X_test)
 
         if self.one_hot:
             self.y_train = np_utils.to_categorical(self.y_train, self.config["num_classes"])
             self.y_test = np_utils.to_categorical(self.y_test, self.config["num_classes"])
@@ -33,10 +33,29 @@ def load(self):

     def preprocess(self, data):
         data = data.astype('float32')
-        data = data - self.mean
-        data = data / self.std
-        return data
+        # data = data - self.mean
+        # data = data / self.std
+        return data / 255.
 
+    def order_data_triplet_loss(self):
+        data = {}
+
+        for label in range(self.config["num_classes"]):
+            mask = self.y_train == label
+            data[label] = [i for i, x in enumerate(mask) if x]
+
+        p_batch = self.config["batch-size"] // self.config["k_batch"]
+        k_batch = self.config["k_batch"]
+
+        X_train, y_train = [], []
+        for i in range(p_batch):
+            for label in data:
+                X_train.extend(self.X_train[data[label][i*k_batch:(i+1)*k_batch]])
+                y_train += [label] * k_batch
+
+        self.X_train = X_train
+        self.y_train = y_train
+
     def get_random_batch(self, k = 100):
         X_batch, y_batch = [], []
         for label in range(self.config["num_classes"]):
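
Two changes here: preprocess now rescales pixels to [0, 1] instead of standardizing by the dataset mean and std, and the new order_data_triplet_loss rearranges the training set so every consecutive run of k_batch samples shares one label, giving each batch guaranteed positives for semi-hard triplet mining. A toy illustration of the layout it produces (hypothetical sizes, not the repository's MNIST shapes):

    import numpy as np

    y = np.array([0, 1, 2, 0, 1, 2, 0, 1, 2, 0, 1, 2])
    k_batch, num_classes = 2, 3
    by_label = {c: np.flatnonzero(y == c) for c in range(num_classes)}
    # take k_batch samples per class, class by class, chunk by chunk
    order = [i for chunk in range(len(y) // (num_classes * k_batch))
               for c in range(num_classes)
               for i in by_label[c][chunk * k_batch:(chunk + 1) * k_batch]]
    print(y[order].tolist())  # [0, 0, 1, 1, 2, 2, 0, 0, 1, 1, 2, 2]
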
49 changes: 43 additions & 6 deletions loss.py
@@ -31,20 +31,27 @@ def __semi_hard_triplet_loss(labels, embeddings, margin = 0.2):
 def __intra_enhanced_triplet_loss(labels, embeddings, lambda_1, alpha, beta, batch_size, k):
     return tf.add(__semi_hard_triplet_loss(labels, embeddings, alpha), tf.multiply(lambda_1, __anchor_center_loss(embeddings, beta, batch_size, k)))
 
-def __large_margin_cos_loss(labels, embeddings, alpha, scale, num_cls):
+def __large_margin_cos_loss(labels, embeddings, alpha, scale, regularization_lambda, num_cls = 10):
     num_features = embeddings.get_shape()[1]
 
-    weights = tf.get_variable("centers", [num_features, num_cls], dtype=tf.float32,
-                initializer=tf.contrib.layers.xavier_initializer(), trainable=True)
+    with tf.variable_scope('centers_scope', reuse = tf.AUTO_REUSE):
+        weights = tf.get_variable("centers", [num_features, num_cls], dtype=tf.float32,
+                    initializer=tf.contrib.layers.xavier_initializer(), regularizer=tf.contrib.layers.l2_regularizer(1e-4), trainable=True)
 
     embedds_feat_norm = tf.nn.l2_normalize(embeddings, 1, 1e-10)
     weights_feat_norm = tf.nn.l2_normalize(weights, 0, 1e-10)
 
     xw_norm = tf.matmul(embedds_feat_norm, weights_feat_norm)
     margin_xw_norm = xw_norm - alpha
 
     labels = tf.squeeze(tf.cast(labels, tf.int32))
     label_onehot = tf.one_hot(labels, num_cls)
     value = scale*tf.where(tf.equal(label_onehot, 1), margin_xw_norm, xw_norm)
 
     cos_loss = tf.reduce_mean(tf.nn.sparse_softmax_cross_entropy_with_logits(labels=labels, logits=value))
 
+    regularization_losses = tf.get_collection(tf.GraphKeys.REGULARIZATION_LOSSES)
+    cos_loss = cos_loss + regularization_lambda * tf.add_n(regularization_losses)
     return cos_loss

def semi_hard_triplet_loss(margin):
@@ -62,5 +69,35 @@ def loss(labels, embeddings):
 def large_margin_cos_loss(config):
     @functools.wraps(__large_margin_cos_loss)
     def loss(labels, embeddings):
-        return __large_margin_cos_loss(labels, embeddings, config["alpha"], config["scale"])
-    return loss
+        return __large_margin_cos_loss(labels, embeddings, config["alpha"], config["scale"], config["reg_lambda"])
+    return loss
+
+def __large_margin_cos_acc(labels, embeddings, alpha, scale, num_cls = 10):
+    num_features = embeddings.get_shape()[1]
+
+    with tf.variable_scope('centers_scope', reuse = tf.AUTO_REUSE):
+        weights = tf.get_variable("centers", [num_features, num_cls], dtype=tf.float32,
+                    initializer=tf.contrib.layers.xavier_initializer(), trainable=True)
+
+    embedds_feat_norm = tf.nn.l2_normalize(embeddings, 1, 1e-10)
+    weights_feat_norm = tf.nn.l2_normalize(weights, 0, 1e-10)
+
+    xw_norm = tf.matmul(embedds_feat_norm, weights_feat_norm)
+    margin_xw_norm = xw_norm - alpha
+
+    labels = tf.squeeze(tf.cast(labels, tf.int32))
+    label_onehot = tf.one_hot(labels, num_cls)
+    value = scale*tf.where(tf.equal(label_onehot, 1), margin_xw_norm, xw_norm)
+
+    logits = tf.nn.softmax(value)
+
+    correct_prediction = tf.equal(tf.argmax(logits, 1), tf.argmax(label_onehot, 1))
+    accuracy = tf.reduce_mean(tf.cast(correct_prediction, tf.float32))
+
+    return accuracy
+
+def large_margin_cos_acc(config):
+    @functools.wraps(__large_margin_cos_acc)
+    def acc(labels, embeddings):
+        return __large_margin_cos_acc(labels, embeddings, config["alpha"], config["scale"])
+    return acc
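
__large_margin_cos_loss follows the CosFace-style recipe: features and class-center columns are L2-normalized, the margin alpha is subtracted from the true class's cosine, everything is multiplied by scale, and the result goes through softmax cross-entropy; wrapping the centers in a variable_scope with reuse=tf.AUTO_REUSE is what lets __large_margin_cos_acc read the same weights at metric time. A NumPy sketch of the forward pass for reference (shapes and names are illustrative, not repository code):

    import numpy as np

    def large_margin_cos_loss_np(emb, W, labels, alpha=0.2, scale=64):
        # emb: (batch, feat) embeddings, W: (feat, num_cls) centers, labels: (batch,)
        emb = emb / np.linalg.norm(emb, axis=1, keepdims=True)  # normalize features
        W = W / np.linalg.norm(W, axis=0, keepdims=True)        # normalize centers
        cos = emb @ W                                           # cosine similarities
        onehot = np.eye(W.shape[1])[labels]
        logits = scale * np.where(onehot == 1, cos - alpha, cos)
        # numerically stable softmax cross-entropy on the true class
        logits -= logits.max(axis=1, keepdims=True)
        log_probs = logits - np.log(np.exp(logits).sum(axis=1, keepdims=True))
        return -log_probs[np.arange(len(labels)), labels].mean()

One detail worth noting: the tf.nn.softmax inside __large_margin_cos_acc does not change the argmax, so the reported accuracy is unaffected by it.
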
78 changes: 0 additions & 78 deletions main.py

This file was deleted.

79 changes: 75 additions & 4 deletions model.py
@@ -1,11 +1,20 @@
 from __future__ import print_function
 
+import os
+import yaml
+import argparse
+
 import tensorflow as tf
 import keras.backend as K
 
 from keras.models import Model
-from keras.layers import Conv2D, MaxPooling2D, Input
+from keras.regularizers import l2
+from keras.layers import Conv2D, MaxPooling2D, AveragePooling2D, Input, GlobalAveragePooling2D, LeakyReLU, SeparableConv2D, BatchNormalization, Add
 from keras.layers.core import Dense, Dropout, Flatten, Lambda
 
 def get_model(input_shape, config, top = True):
     input_img = Input(input_shape)
+    num_classes = config["data"]["num_classes"]
 
     def __body(input_img):
         x = Conv2D(32, kernel_size=(3, 3), activation='relu')(input_img)
@@ -17,12 +26,74 @@ def __body(input_img):
         return embedding
 
     def __head(embedding):
-        x = Dropout(0.25)(embedding)
-        out = Dense(config["data"]["num_classes"], activation='softmax')(x)
+        x = Dropout(0.5)(embedding)
+        out = Dense(num_classes, activation='softmax')(x)
         return out
 
     x = __body(input_img)
     if top: x = __head(x)
 
     model = Model(inputs=input_img, outputs=x)
     return model
+
+def simple_resnet(input_shape):
+
+    repetitions = [1, 1, 1]
+    def add_common_layers(x):
+        x = BatchNormalization()(x)
+        x = LeakyReLU()(x)
+        return x
+
+    def residual_block(x, mul = 1, is_shortcut = False):
+        shortcut = x
+
+        x = SeparableConv2D(16 * mul, 1, padding="same", kernel_regularizer=l2(1e-4))(x)
+        x = add_common_layers(x)
+
+        x = SeparableConv2D(16 * mul, 3, padding="same", kernel_regularizer=l2(1e-4))(x)
+        x = add_common_layers(x)
+
+        x = SeparableConv2D(32 * mul, 1, padding="same", kernel_regularizer=l2(1e-4))(x)
+        x = BatchNormalization()(x)
+
+        if is_shortcut:
+            shortcut = SeparableConv2D(32 * mul, 1, padding='same', kernel_regularizer=l2(1e-4))(shortcut)
+            shortcut = BatchNormalization()(shortcut)
+
+        x = Add()([x, shortcut])
+        x = LeakyReLU()(x)
+
+        return x
+
+    input_img = Input(input_shape)
+    x = Conv2D(32, 7, strides=2, padding="same", kernel_regularizer=l2(1e-4))(input_img)
+    x = LeakyReLU()(x)
+    x = MaxPooling2D(2, strides=2, padding="same")(x)
+
+    for i, r in enumerate(repetitions):
+        for j in range(r):
+            x = residual_block(x, mul = 2**i, is_shortcut = (j==0))
+        if i < len(repetitions) - 1:
+            x = MaxPooling2D(2, strides=2, padding="same")(x)
+        else:
+            x = AveragePooling2D(2, strides=7)(x)
+
+    x = GlobalAveragePooling2D()(x)
+    model = Model(inputs=input_img, outputs=x)
+    return model
+
+
+if __name__ == "__main__":
+    parser = argparse.ArgumentParser(description='Model Parameters')
+    parser.add_argument('-c', '--config', type=str, default="config.yaml", help='path of config file')
+    args = parser.parse_args()
+
+    with open(args.config, 'r') as file:
+        config = yaml.load(file)
+
+    data = config["data"]
+    input_shape = (data["imsize"], data["imsize"], data["imchannel"])
+
+    model = get_model(input_shape, config, top = True)
+    model.summary()
+    print("Parameters: {}".format(model.count_params()))
12 changes: 7 additions & 5 deletions train.py
@@ -12,7 +12,7 @@

 from loss import *
 from data import DataLoader
-from model import get_model
+from model import get_model, simple_resnet
 
 def get_loss_function(func):
     return {
@@ -40,18 +40,20 @@ def get_loss_function(func):
     with open(os.path.join(paths["save"], config["run-title"] + ".yaml"), 'w') as outfile:
         yaml.dump(config, outfile)
 
-    dataloader = DataLoader(data, config["train"]["loss"]=="categorical-crossentropy")
+    dataloader = DataLoader(data, train["loss"]=="categorical-crossentropy")
     dataloader.load()
 
     input_shape = (data["imsize"], data["imsize"], data["imchannel"])
-    model = get_model(input_shape, config, top=config["train"]["loss"]=="categorical-crossentropy")
+    model = get_model(input_shape, config, top=train["loss"]=="categorical-crossentropy")
+    # model = simple_resnet(input_shape)
 
     if train["resume"]:
         model.load_weights(paths["load"], by_name=True)
 
-    loss_func = get_loss_function(config["train"]["loss"])
+    metric = large_margin_cos_acc(train) if train["loss"]=="large-margin-cosine-loss" else 'acc'
+    loss_func = get_loss_function(train["loss"])
     optim = getattr(optimizers, train["optim"])(train["lr"])
-    model.compile(loss=loss_func, optimizer=optim, metrics=[])
+    model.compile(loss=loss_func, optimizer=optim, metrics=[metric])
 
     reduce_lr = ReduceLROnPlateau(monitor='val_loss', factor=train["lr_reduce_factor"], patience=train["patience"], min_lr=train["min_lr"])
     checkpoint = ModelCheckpoint(os.path.join(paths["save"],"model.{epoch:02d}-{val_loss:.4f}.h5"), monitor='val_loss', save_best_only=True, mode='min')
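
Since the cosine-loss model outputs embeddings rather than class probabilities, Keras's built-in 'acc' would be meaningless there; large_margin_cos_acc(train) instead rebuilds the scaled cosine logits from the shared centers and reports argmax accuracy. It reuses the same closure pattern as the losses, sketched here with a hypothetical metric (not repository code):

    import keras.backend as K

    def make_metric(config):
        # bake config into a Keras-compatible (y_true, y_pred) callable,
        # mirroring how loss.py wraps __large_margin_cos_acc;
        # assumes one-hot y_true for illustration
        def acc(y_true, y_pred):
            matches = K.equal(K.argmax(y_true, axis=-1), K.argmax(y_pred, axis=-1))
            return K.mean(K.cast(matches, 'float32'))
        return acc

    # usage: model.compile(loss=loss_func, optimizer=optim, metrics=[make_metric(train)])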
