Skip to content

Commit

Permalink
added intra-enhanced triplet loss
Browse files Browse the repository at this point in the history
  • Loading branch information
BKHMSI committed Oct 8, 2018
1 parent fb7a253 commit 0bbc721
Show file tree
Hide file tree
Showing 6 changed files with 72 additions and 24 deletions.
3 changes: 2 additions & 1 deletion .gitignore
Original file line number Diff line number Diff line change
@@ -1,2 +1,3 @@
*.pyc
Models
Models
Graph
10 changes: 5 additions & 5 deletions config.yaml
Original file line number Diff line number Diff line change
@@ -1,17 +1,17 @@
run-title: "TripletMnist"
run-title: "SoftmaxMnist"

paths:
save: "Models/2"
load: "Models/2/model.14-0.0065.h5"
save: "Models/1"
load: ""

train:
lr: 0.001
optim: "Nadam"

epochs: 100
batch-size: 400
batch-size: 250

loss: "semi-hard-triplet-loss"
loss: "categorical-crossentropy"
alpha: 0.2

lr_reduce_factor: 0.5
Expand Down
16 changes: 9 additions & 7 deletions data.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,8 +4,9 @@
from keras.datasets import mnist

class DataLoader(object):
    """Loads and preprocesses the MNIST dataset for training/evaluation."""

    def __init__(self, config, one_hot = False):
        """Stores the data configuration.

        Args:
            config: dict of data settings (e.g. "num_classes", "val_split").
            one_hot: when True, load() converts the integer labels to
                one-hot vectors (needed for categorical cross-entropy;
                the triplet losses use raw integer labels).
        """
        self.config = config
        self.one_hot = one_hot

def load(self):
(X_train, self.y_train), (X_test, self.y_test) = mnist.load_data()
Expand All @@ -21,8 +22,9 @@ def load(self):
self.X_train = self.preprocess(X_train)
self.X_test = self.preprocess(X_test)

self.Y_train = np_utils.to_categorical(self.y_train, self.config["num_classes"])
self.Y_test = np_utils.to_categorical(self.y_test, self.config["num_classes"])
if self.one_hot:
self.y_train = np_utils.to_categorical(self.y_train, self.config["num_classes"])
self.y_test = np_utils.to_categorical(self.y_test, self.config["num_classes"])

self.num_train = int(self.y_train.shape[0] * (1-self.config["val_split"]))
self.num_val = int(self.y_train.shape[0] * (self.config["val_split"]))
Expand All @@ -38,9 +40,9 @@ def preprocess(self, data):
def get_random_batch(self, k = 100):
    """Samples a batch from the test set with (up to) k examples per class.

    Args:
        self: DataLoader-like object providing config["num_classes"],
            X_test, y_test and input_shape.
        k: number of samples drawn without replacement from each class.
            If k is negative or exceeds a class's size, ALL samples of
            that class are used instead (e.g. k = -1 selects everything).

    Returns:
        (X_batch, y_batch): images reshaped to self.input_shape, and the
        matching integer labels as a numpy array.
    """
    X_batch, y_batch = [], []
    for label in range(self.config["num_classes"]):
        X_label = self.X_test[self.y_test == label]
        if 0 <= k <= len(X_label):
            # Draw k distinct samples of this class.
            picked = X_label[np.random.choice(len(X_label), k, replace=False)]
        else:
            # k out of range: fall back to every sample of this class.
            picked = X_label
        X_batch.append(picked)
        y_batch += [label] * len(picked)
    # Concatenate before reshaping: reshaping the raw list of arrays fails
    # as a ragged array when classes contribute different sample counts
    # (e.g. k = -1 on MNIST, where class sizes are unequal).
    X_batch = np.reshape(np.concatenate(X_batch), self.input_shape)
    return X_batch, np.array(y_batch)

41 changes: 41 additions & 0 deletions loss.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,11 +3,52 @@
import tensorflow as tf
from keras import backend as K

def __anchor_center_loss(embeddings, margin, batch_size = 240, k = 4):
    """Computes the anchor-center loss.

    Minimizes intra-class distances. Assumes embeddings are ordered
    such that every k samples belong to the same class, where the
    number of classes is batch_size // k.

    Args:
        embeddings: tensor of shape (batch_size, embed_dim)
        margin: intra-class distances should be within this margin
        batch_size: number of embeddings
        k: number of samples per class in embeddings

    Returns:
        loss: scalar tensor containing the anchor-center loss
    """
    loss = tf.constant(0, dtype='float32')
    for i in range(0, batch_size, k):
        group = embeddings[i:i + k]
        # Class center = mean embedding over the k samples of this group.
        center = tf.reduce_mean(group, 0)
        # Hinge on each sample's squared distance to its class center:
        # only distances exceeding the margin contribute.
        violation = tf.maximum(tf.reduce_sum(tf.square(group - center), axis=1) - margin, 0.)
        loss = tf.add(loss, tf.reduce_sum(violation))
    # `loss` is already a scalar; the original trailing tf.reduce_mean(loss)
    # was a no-op (mean of a scalar) and has been removed.
    return loss

def __semi_hard_triplet_loss(labels, embeddings, margin = 0.2):
    # Thin wrapper around TensorFlow's semi-hard triplet mining loss.
    # NOTE(review): tf.contrib was removed in TF 2.x — this requires TF 1.x
    # (tensorflow_addons.losses.triplet_semihard_loss is the modern equivalent).
    return tf.contrib.losses.metric_learning.triplet_semihard_loss(labels, embeddings, margin=margin)

def __intra_enhanced_triplet_loss(labels, embeddings, lambda_1, alpha, beta, batch_size, k):
    """Semi-hard triplet loss (margin alpha) plus lambda_1-weighted
    anchor-center loss (margin beta)."""
    # Inter-class term: push different-class embeddings apart.
    inter_class = __semi_hard_triplet_loss(labels, embeddings, alpha)
    # Intra-class term: pull same-class embeddings toward their center.
    intra_class = tf.multiply(lambda_1, __anchor_center_loss(embeddings, beta, batch_size, k))
    return tf.add(inter_class, intra_class)

def __large_margin_cos_loss(labels, embeddings):
    # TODO: large-margin cosine loss is not implemented yet — `labels` and
    # `embeddings` are ignored and the returned loss is always 0.
    loss = tf.constant(0, dtype='float32')
    return loss

def semi_hard_triplet_loss(margin):
    """Builds a Keras-compatible (labels, embeddings) loss that closes
    over the given triplet margin."""
    @functools.wraps(__semi_hard_triplet_loss)
    def wrapped(labels, embeddings):
        return __semi_hard_triplet_loss(labels, embeddings, margin)
    return wrapped

def intra_enhanced_triplet_loss(train, data):
    """Builds a Keras-compatible loss combining the semi-hard triplet and
    anchor-center terms, parameterized by the `train`/`data` config sections."""
    @functools.wraps(__intra_enhanced_triplet_loss)
    def wrapped(labels, embeddings):
        # Config values are looked up at call time, matching the original.
        return __intra_enhanced_triplet_loss(
            labels, embeddings,
            train["lambda_1"], train["alpha"], train["beta"],
            train["batch-size"], data["k_batch"],
        )
    return wrapped

def large_margin_cos_loss(config):
    # Factory returning a Keras-compatible (labels, embeddings) loss.
    # NOTE(review): `config` is accepted for symmetry with the other loss
    # factories but is currently unused — the underlying loss is a stub
    # that always returns 0.
    @functools.wraps(__large_margin_cos_loss)
    def loss(labels, embeddings):
        return __large_margin_cos_loss(labels, embeddings)
    return loss
23 changes: 13 additions & 10 deletions train.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,20 +5,22 @@
import argparse
import numpy as np
import tensorflow as tf
import keras.optimizers
import keras.optimizers as optimizers

from keras import losses
from keras.callbacks import ModelCheckpoint, ReduceLROnPlateau
from keras.callbacks import ModelCheckpoint, ReduceLROnPlateau, TensorBoard

from loss import *
from data import DataLoader
from model import get_model

def get_loss_function(func):
    """Maps a loss name (e.g. "semi-hard-triplet-loss") to a Keras-compatible
    loss function, falling back to categorical cross-entropy for unknown names.

    NOTE(review): despite taking the loss name as `func`, this still reads the
    module-level `config` (defined in the __main__ block) to parameterize the
    loss factories — confirm `config` is loaded before calling. All factories
    are invoked eagerly when building the dict; they are cheap closures, so
    this is harmless.
    """
    return {
        'large-margin-cosine-loss': large_margin_cos_loss(config["train"]),
        'intra-enhanced-triplet-loss': intra_enhanced_triplet_loss(config["train"], config["data"]),
        'semi-hard-triplet-loss': semi_hard_triplet_loss(config["train"]["alpha"]),
        'categorical-crossentropy': losses.categorical_crossentropy,
    }.get(func, losses.categorical_crossentropy)

if __name__ == "__main__":

Expand All @@ -38,7 +40,7 @@ def get_loss_function(config):
with open(os.path.join(paths["save"], config["run-title"] + ".yaml"), 'w') as outfile:
yaml.dump(config, outfile)

dataloader = DataLoader(data)
dataloader = DataLoader(data, config["train"]["loss"]=="categorical-crossentropy")
dataloader.load()

input_shape = (data["imsize"], data["imsize"], data["imchannel"])
Expand All @@ -47,18 +49,19 @@ def get_loss_function(config):
if train["resume"]:
model.load_weights(paths["load"], by_name=True)

loss_func = get_loss_function(config)

optim = getattr(keras.optimizers, train["optim"])(train["lr"])
loss_func = get_loss_function(config["train"]["loss"])
optim = getattr(optimizers, train["optim"])(train["lr"])
model.compile(loss=loss_func, optimizer=optim, metrics=[])

reduce_lr = ReduceLROnPlateau(monitor='val_loss', factor=train["lr_reduce_factor"], patience=train["patience"], min_lr=train["min_lr"])
checkpoint = ModelCheckpoint(os.path.join(paths["save"],"model.{epoch:02d}-{val_loss:.4f}.h5"), monitor='val_loss', save_best_only=True, mode='min')
tensorboard = TensorBoard(log_dir=os.path.join('./Graph',config["run-title"]), histogram_freq=0, write_graph=True, write_images=True)

model.fit(dataloader.X_train, dataloader.y_train,
epochs=train["epochs"],
batch_size=train["batch-size"],
verbose=1,
shuffle=True,
validation_split=data["val_split"],
callbacks=[checkpoint, reduce_lr]
callbacks=[checkpoint, reduce_lr, tensorboard]
)
3 changes: 2 additions & 1 deletion visualize.py
Original file line number Diff line number Diff line change
Expand Up @@ -46,8 +46,9 @@ def scatter(x, labels, config):

model.load_weights(paths["load"], by_name=True)

X_batch, y_batch = dataloader.get_random_batch(k = 500)
X_batch, y_batch = dataloader.get_random_batch(k = -1)

#embeddings = X_batch.reshape(-1, 784)
embeddings = model.predict(X_batch, batch_size=config["train"]["batch-size"], verbose=1)

tsne = TSNE(n_components=2, perplexity=30, verbose=1)
Expand Down

0 comments on commit 0bbc721

Please sign in to comment.