1616import model .plot as plot
1717from architecture .single_model import img_alexnet_layers
1818from evaluation import MAPs
19- from .util import Dataset
2019
2120
22- class PruneHash (object ):
23- def __init__ (self , config , stage ):
21+ class DCH (object ):
22+ def __init__ (self , config ):
2423 ### Initialize setting
2524 print ("initializing" )
2625 np .set_printoptions (precision = 4 )
27- self .stage = stage
28- self .device = config ['device' ]
29- self .output_dim = config ['output_dim' ]
30- self .n_class = config ['label_dim' ]
31- self .cq_lambda = config ['cq_lambda' ]
32- self .alpha = config ['alpha' ]
33- self .bias = config ['bias' ]
34- self .gamma = config ['gamma' ]
35-
36- self .batch_size = config ['batch_size' ] if self .stage == "train" else config ['val_batch_size' ]
37- self .max_iter = config ['max_iter' ]
38- self .img_model = config ['img_model' ]
39- self .loss_type = config ['loss_type' ]
40- self .learning_rate = config ['learning_rate' ]
41- self .learning_rate_decay_factor = config ['learning_rate_decay_factor' ]
42- self .decay_step = config ['decay_step' ]
43-
44- self .finetune_all = config ['finetune_all' ]
4526
27+ with tf .name_scope ('stage' ):
28+ # 0 for training, 1 for validation
29+ self .stage = tf .placeholder_with_default (tf .constant (0 ), [])
30+ for k , v in vars (config ).items ():
31+ setattr (self , k , v )
4632 self .file_name = 'loss_{}_lr_{}_cqlambda_{}_alpha_{}_bias_{}_gamma_{}_dataset_{}' .format (
4733 self .loss_type ,
48- self .learning_rate ,
49- self .cq_lambda ,
34+ self .lr ,
35+ self .q_lambda ,
5036 self .alpha ,
5137 self .bias ,
5238 self .gamma ,
53- config ['dataset' ])
54- self .save_dir = config ['save_dir' ]
55- self .save_file = os .path .join (config ['save_dir' ], self .file_name + '.npy' )
56- self .log_dir = config ['log_dir' ]
39+ self .dataset )
40+ self .save_file = os .path .join (self .save_dir , self .file_name + '.npy' )
5741
5842 ### Setup session
5943 print ("launching session" )
@@ -63,27 +47,25 @@ def __init__(self, config, stage):
6347 self .sess = tf .Session (config = configProto )
6448
6549 ### Create variables and placeholders
50+ self .img = tf .placeholder (tf .float32 , [None , 256 , 256 , 3 ])
51+ self .img_label = tf .placeholder (tf .float32 , [None , self .label_dim ])
52+ self .img_last_layer , self .deep_param_img , self .train_layers , self .train_last_layer = self .load_model ()
6653
67- with tf .device (self .device ):
68- self .img = tf .placeholder (tf .float32 , [self .batch_size , 256 , 256 , 3 ])
69- self .img_label = tf .placeholder (tf .float32 , [self .batch_size , self .n_class ])
70-
71- if self .stage == 'train' :
72- self .model_weights = config ['model_weights' ]
73- else :
74- self .model_weights = self .save_file
75- self .img_last_layer , self .deep_param_img , self .train_layers , self .train_last_layer = self .load_model ()
76-
77- self .global_step = tf .Variable (0 , trainable = False )
78- self .train_op = self .apply_loss_function (self .global_step )
79- self .sess .run (tf .global_variables_initializer ())
54+ self .global_step = tf .Variable (0 , trainable = False )
55+ self .train_op = self .apply_loss_function (self .global_step )
56+ self .sess .run (tf .global_variables_initializer ())
8057 return
8158
8259 def load_model (self ):
8360 if self .img_model == 'alexnet' :
8461 img_output = img_alexnet_layers (
85- self .img , self .batch_size , self .output_dim ,
86- self .stage , self .model_weights )
62+ self .img ,
63+ self .batch_size ,
64+ self .output_dim ,
65+ self .stage ,
66+ self .model_weights ,
67+ self .with_tanh ,
68+ self .val_batch_size )
8769 else :
8870 raise Exception ('cannot use such CNN model as ' + self .img_model )
8971 return img_output
@@ -139,7 +121,7 @@ def reduce_shaper(t):
139121 r = tf .reshape (r , [- 1 , 1 ])
140122 ip = r - 2 * tf .matmul (u , tf .transpose (u )) + tf .transpose (r )
141123
142- ip = tf . constant ( self .gamma ) / (ip + tf . constant ( self .gamma ) * tf . constant ( self . gamma ) )
124+ ip = self .gamma / (ip + self .gamma ** 2 )
143125 else :
144126 ip = tf .clip_by_value (tf .matmul (u , tf .transpose (u )), - 1.5e1 , 1.5e1 )
145127 ones = tf .ones ([tf .shape (u )[0 ], tf .shape (u )[0 ]])
@@ -158,13 +140,12 @@ def apply_loss_function(self, global_step):
158140 self .cos_loss = self .cross_entropy (self .img_last_layer , self .img_label , self .alpha , True , True , self .bias )
159141
160142 self .q_loss_img = tf .reduce_mean (tf .square (tf .subtract (tf .abs (self .img_last_layer ), tf .constant (1.0 ))))
161- self .q_lambda = tf .Variable (self .cq_lambda , name = 'cq_lambda' )
162- self .q_loss = tf .multiply (self .q_lambda , self .q_loss_img )
143+ self .q_loss = self .q_lambda * self .q_loss_img
163144 self .loss = self .cos_loss + self .q_loss
164145
165146 ### Last layer has a 10 times learning rate
166- self . lr = tf .train .exponential_decay (self .learning_rate , global_step , self .decay_step , self .learning_rate_decay_factor , staircase = True )
167- opt = tf .train .MomentumOptimizer (learning_rate = self . lr , momentum = 0.9 )
147+ lr = tf .train .exponential_decay (self .lr , global_step , self .decay_step , self .lr , staircase = True )
148+ opt = tf .train .MomentumOptimizer (learning_rate = lr , momentum = 0.9 )
168149 grads_and_vars = opt .compute_gradients (self .loss , self .train_layers + self .train_last_layer )
169150 fcgrad , _ = grads_and_vars [- 2 ]
170151 fbgrad , _ = grads_and_vars [- 1 ]
@@ -174,11 +155,11 @@ def apply_loss_function(self, global_step):
174155 tf .summary .scalar ('loss' , self .loss )
175156 tf .summary .scalar ('cos_loss' , self .cos_loss )
176157 tf .summary .scalar ('q_loss' , self .q_loss )
177- tf .summary .scalar ('lr' , self . lr )
158+ tf .summary .scalar ('lr' , lr )
178159 self .merged = tf .summary .merge_all ()
179160
180161
181- if self .stage == "train" and self . finetune_all :
162+ if self .finetune_all :
182163 return opt .apply_gradients ([(grads_and_vars [0 ][0 ], self .train_layers [0 ]),
183164 (grads_and_vars [1 ][0 ]* 2 , self .train_layers [1 ]),
184165 (grads_and_vars [2 ][0 ], self .train_layers [2 ]),
@@ -208,13 +189,10 @@ def train(self, img_dataset):
208189 shutil .rmtree (tflog_path )
209190 train_writer = tf .summary .FileWriter (tflog_path , self .sess .graph )
210191
211- for train_iter in range (self .max_iter ):
192+ for train_iter in range (self .iter_num ):
212193 images , labels = img_dataset .next_batch (self .batch_size )
213194 start_time = time .time ()
214195
215- assign_lambda = self .q_lambda .assign (self .cq_lambda )
216- self .sess .run ([assign_lambda ])
217-
218196 _ , loss , cos_loss , output , summary = self .sess .run ([self .train_op , self .loss , self .cos_loss , self .img_last_layer , self .merged ],
219197 feed_dict = {self .img : images ,
220198 self .img_label : labels })
@@ -224,7 +202,7 @@ def train(self, img_dataset):
224202 img_dataset .feed_batch_output (self .batch_size , output )
225203 duration = time .time () - start_time
226204
227- if train_iter % 1 == 0 :
205+ if train_iter % 100 == 0 :
228206 print ("%s #train# step %4d, loss = %.4f, cross_entropy loss = %.4f, %.1f sec/batch"
229207 % (datetime .now (), train_iter + 1 , loss , cos_loss , duration ))
230208
@@ -236,24 +214,29 @@ def train(self, img_dataset):
236214
237215 def validation (self , img_query , img_database , R = 100 ):
238216 print ("%s #validation# start validation" % (datetime .now ()))
239- query_batch = int (ceil (img_query .n_samples / self .batch_size ))
217+ query_batch = int (ceil (img_query .n_samples / float (self .val_batch_size )))
218+ img_query .finish_epoch ()
240219 print ("%s #validation# totally %d query in %d batches" % (datetime .now (), img_query .n_samples , query_batch ))
241220 for i in range (query_batch ):
242- images , labels = img_query .next_batch (self .batch_size )
221+ images , labels = img_query .next_batch (self .val_batch_size )
243222 output , loss = self .sess .run ([self .img_last_layer , self .cos_loss ],
244- feed_dict = {self .img : images , self .img_label : labels })
245- img_query .feed_batch_output (self .batch_size , output )
223+ feed_dict = {self .img : images ,
224+ self .img_label : labels ,
225+ self .stage : 1 })
226+ img_query .feed_batch_output (self .val_batch_size , output )
246227 print ('Cosine Loss: %s' % loss )
247228
248- database_batch = int (ceil (img_database .n_samples / self .batch_size ))
229+ database_batch = int (ceil (img_database .n_samples / float (self .val_batch_size )))
230+ img_database .finish_epoch ()
249231 print ("%s #validation# totally %d database in %d batches" % (datetime .now (), img_database .n_samples , database_batch ))
250232 for i in range (database_batch ):
251- images , labels = img_database .next_batch (self .batch_size )
233+ images , labels = img_database .next_batch (self .val_batch_size )
252234
253235 output , loss = self .sess .run ([self .img_last_layer , self .cos_loss ],
254- feed_dict = {self .img : images , self .img_label : labels })
255- img_database .feed_batch_output (self .batch_size , output )
256- #print output[:10, :10]
236+ feed_dict = {self .img : images ,
237+ self .img_label : labels ,
238+ self .stage : 1 })
239+ img_database .feed_batch_output (self .val_batch_size , output )
257240 if i % 100 == 0 :
258241 print ('Cosine Loss[%d/%d]: %s' % (i , database_batch , loss ))
259242
@@ -283,15 +266,3 @@ def validation(self, img_query, img_database, R=100):
283266 'i2i_map_radius_2' : mmap ,
284267 }
285268
286- def train (train_img , config ):
287- model = PruneHash (config , 'train' )
288- img_dataset = Dataset (train_img , config ['output_dim' ])
289- model .train (img_dataset )
290- return model .save_file
291-
292- def validation (database_img , query_img , config ):
293- model = PruneHash (config , 'val' )
294- img_database = Dataset (database_img , config ['output_dim' ])
295- img_query = Dataset (query_img , config ['output_dim' ])
296- return model .validation (img_query , img_database , config ['R' ])
297-
0 commit comments