|
5 | 5 | # set log level to debug
|
6 | 6 | tf.sg_verbosity(10)
|
7 | 7 |
|
| 8 | +class Hyperparams: |
| 9 | + batch_size = 64 |
| 10 | + |
8 | 11 | def load_data(is_train=True):
|
9 |
| - Y = np.load('data/sudoku.npy') # solutions |
| 12 | + '''Loads training / validation data. |
10 | 13 |
|
11 |
| - X = np.zeros_like(Y, dtype=np.float32) |
12 |
| - for i, y in enumerate(Y): # game-wise |
13 |
| - nblanks = np.random.randint(1, 65) # We generate a problem which varies from 1 to 65 in number of blanks. |
14 |
| - blank_indices = np.random.choice(81, nblanks) |
15 |
| - masks= np.ones((9*9)) |
16 |
| - masks[blank_indices] = 0 |
17 |
| - masks = masks.reshape((9, 9)) |
18 |
| - |
19 |
| - x = y * masks # puzzle. 0: blanks=targets. |
20 |
| - X[i] = x |
| 14 | + Args |
| 15 | + is_train: Boolean. If True, it loads training data. |
| 16 | + Otherwise, it loads validation data. |
| 17 | + |
| 18 | + Returns: |
| 19 | + X: 4-D array of float. Has the shape of (# total games, 9, 9, 1) (for train) |
| 20 | + or (batch_size, 9, 9, 1) (for validation) |
| 21 | + Y: 3-D array of int. Has the shape of (# total games, 9, 9) (for train) |
| 22 | + or (batch_size, 9, 9) (for validation) |
| 23 | + ''' |
| 24 | + X = np.load('data/sudoku.npz')['quizzes'].astype(np.float32) |
| 25 | + Y = np.load('data/sudoku.npz')['solutions'] |
21 | 26 |
|
22 | 27 | X = np.expand_dims(X, -1)
|
23 | 28 |
|
24 | 29 | if is_train:
|
25 |
| - return X[:-100], Y[:-100] # training data |
| 30 | + return X[:-Hyperparams.batch_size], Y[:-Hyperparams.batch_size] # training data |
26 | 31 | else:
|
27 |
| - return X[-100:], Y[-100:] # validation data |
| 32 | + return X[-Hyperparams.batch_size:], Y[-Hyperparams.batch_size:] # validation data |
| 33 | + |
| 34 | +def get_batch_data(is_train=True): |
| 35 | + '''Returns batch data. |
28 | 36 |
|
29 |
| -def get_batch_data(is_train=True, batch_size=16): |
30 |
| - ''' |
31 | 37 | Args:
|
32 |
| - is_train: Boolean. If True, load training data. Otherwise, load validation data. |
| 38 | + is_train: Boolean. If True, it returns batch training data. |
| 39 | + Otherwise, batch validation data. |
| 40 | + |
33 | 41 | Returns:
|
34 |
| - A Tuple of X batch queues (Tensor), Y batch queues (Tensor), and number of batches (int) |
| 42 | + A Tuple of x, y, and num_batch |
| 43 | + x: A `Tensor` of float. Has the shape of (batch_size, 9, 9, 1). |
| 44 | + y: A `Tensor` of int. Has the shape of (batch_size, 9, 9). |
| 45 | + num_batch = A Python int. Number of batches. |
35 | 46 | '''
|
36 |
| - # Load data |
37 | 47 | X, Y = load_data(is_train=is_train)
|
38 | 48 |
|
39 | 49 | # Create Queues
|
40 | 50 | input_queues = tf.train.slice_input_producer([tf.convert_to_tensor(X),
|
41 | 51 | tf.convert_to_tensor(Y)])
|
42 | 52 |
|
43 | 53 | # create batch queues
|
44 |
| - X_batch, Y_batch = tf.train.shuffle_batch(input_queues, |
45 |
| - num_threads=8, |
46 |
| - batch_size=batch_size, |
47 |
| - capacity=batch_size*64, |
48 |
| - min_after_dequeue=batch_size*32, |
49 |
| - allow_smaller_final_batch=False) |
| 54 | + x, y = tf.train.shuffle_batch(input_queues, |
| 55 | + num_threads=8, |
| 56 | + batch_size=Hyperparams.batch_size, |
| 57 | + capacity=Hyperparams.batch_size*64, |
| 58 | + min_after_dequeue=Hyperparams.batch_size*32, |
| 59 | + allow_smaller_final_batch=False) |
50 | 60 | # calc total batch count
|
51 | 61 | num_batch = len(X) // batch_size
|
52 | 62 |
|
53 |
| - return X_batch, Y_batch, num_batch # (16, 9, 9, 1) int32. cf. Y_batch: (16, 9, 9) int32 |
| 63 | + return x, y, num_batch # (64, 9, 9, 1), (64, 9, 9), () |
54 | 64 |
|
55 | 65 | class Graph(object):
|
56 | 66 | def __init__(self, is_train=True):
|
57 | 67 | # inputs
|
58 | 68 | if is_train:
|
59 |
| - self.X, self.Y, self.num_batch = get_batch_data() # (16, 9, 9, 1), (16, 9, 9) |
60 |
| - self.X_val, self.Y_val, _ = get_batch_data(is_train=False) |
| 69 | + self.x, self.y, self.num_batch = get_batch_data() |
| 70 | + self.x_val, self.y_val, _ = get_batch_data(is_train=False) |
61 | 71 | else:
|
62 |
| - self.X = tf.placeholder(tf.float32, [None, 9, 9, 1]) |
| 72 | + self.x = tf.placeholder(tf.float32, [None, 9, 9, 1]) |
63 | 73 |
|
64 | 74 | with tf.sg_context(size=3, act='relu', bn=True):
|
65 |
| - self.logits = self.X.sg_identity() |
66 |
| - for _ in range(5): |
| 75 | + self.logits = self.x.sg_identity() |
| 76 | + for _ in range(10): |
67 | 77 | self.logits = (self.logits.sg_conv(dim=512))
|
68 |
| - self.logits = self.logits.sg_conv(dim=10, size=1, act='linear', bn=False) # (16, 9, 9, 10) float32 |
| 78 | + |
| 79 | + self.logits = self.logits.sg_conv(dim=10, size=1, act='linear', bn=False) |
69 | 80 |
|
70 | 81 | if is_train:
|
71 |
| - self.ce = self.logits.sg_ce(target=self.Y, mask=False) # (16, 9, 9) dtype=float32 |
72 |
| - self.istarget = tf.equal(self.X.sg_squeeze(), tf.zeros_like(self.X.sg_squeeze())).sg_float() # zeros: 1, non-zeros: 0 (16, 9, 9) dtype=float32 |
73 |
| - self.loss = self.ce * self.istarget # (16, 9, 9) dtype=float32 |
| 82 | + self.ce = self.logits.sg_ce(target=self.y, mask=False) |
| 83 | + self.istarget = tf.equal(self.x.sg_squeeze(), tf.zeros_like(self.x.sg_squeeze())).sg_float() |
| 84 | + self.loss = self.ce * self.istarget |
74 | 85 | self.reduced_loss = self.loss.sg_sum() / self.istarget.sg_sum()
|
75 | 86 | tf.sg_summary_loss(self.reduced_loss, "reduced_loss")
|
76 | 87 |
|
77 |
| - # accuracy evaluation ( for train set ) |
78 |
| - self.preds = (self.logits.sg_argmax()).sg_int() |
79 |
| - self.hits = tf.equal(self.preds, self.Y).sg_float() |
80 |
| - self.acc_train = (self.hits * self.istarget).sg_sum() / self.istarget.sg_sum() |
81 |
| - |
82 | 88 | # accuracy evaluation ( for validation set )
|
83 |
| - self.preds_ = (self.logits.sg_reuse(input=self.X_val).sg_argmax()).sg_int() |
84 |
| - self.hits_ = tf.equal(self.preds_, self.Y_val).sg_float() |
85 |
| - self.istarget_ = tf.equal(self.X_val.sg_squeeze(), tf.zeros_like(self.X_val.sg_squeeze())).sg_float() |
86 |
| - self.acc_val = (self.hits_ * self.istarget_).sg_sum() / self.istarget_.sg_sum() |
| 89 | + self.preds_ = (self.logits.sg_reuse(input=self.x_val).sg_argmax()).sg_int() |
| 90 | + self.hits_ = tf.equal(self.preds_, self.y_val).sg_float() |
| 91 | + self.istarget_ = tf.equal(self.x_val.sg_squeeze(), tf.zeros_like(self.x_val.sg_squeeze())).sg_float() |
| 92 | + self.acc = (self.hits_ * self.istarget_).sg_sum() / self.istarget_.sg_sum() |
87 | 93 |
|
88 | 94 | def main():
|
89 | 95 | g = Graph()
|
90 | 96 |
|
91 |
| - tf.sg_train(log_interval=10, loss=g.reduced_loss, eval_metric=[g.acc_train, g.acc_val], |
| 97 | + tf.sg_train(lr=0.0001, lr_reset=True, log_interval=10, save_interval=300, |
| 98 | + loss=g.reduced_loss, eval_metric=[g.acc], |
92 | 99 | ep_size=g.num_batch, save_dir='asset/train', max_ep=10, early_stop=False)
|
93 | 100 |
|
94 | 101 | if __name__ == "__main__":
|
|
0 commit comments