-
Notifications
You must be signed in to change notification settings - Fork 0
/
run_script_7.py
35 lines (27 loc) · 904 Bytes
/
run_script_7.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
import tensorflow as tf
import numpy as np
import Functions as fn
import importlib
# Re-execute the already-imported Functions module so `fn` reflects its
# current source (presumably for interactive re-runs of this script
# after editing Functions.py; confirm).
importlib.reload(fn)
# testdatanorm = np.random.rand(100, 16, 100)
# testlabelsnorm = np.random.rand(100, 100)
# Checkpointing and logging policy for training runs.
checkpointing_config = tf.estimator.RunConfig(
    save_checkpoints_secs=20 * 60,  # write a checkpoint every 20 minutes
    keep_checkpoint_max=2,  # keep only the 2 most recent checkpoints
    save_summary_steps=1000,  # save summaries every 1000 steps
    log_step_count_steps=1000,  # log step count / loss every 1000 steps
)
# Estimator built around the CNN model function from the local Functions
# module; 'CNN_multi_pulse' is passed as its model directory.
classifier = tf.estimator.Estimator(
    model_fn=fn.CNNmodel,
    model_dir='CNN_multi_pulse',
    config=checkpointing_config,
    # Hyperparameters handed to fn.CNNmodel through `params`. How each entry
    # is interpreted is defined in Functions.py, which is not visible here.
    params={
        # 'feature_columns': the_feature_column,
        'NUM_COOKIES': 16,
        'CNN': [[12, 10]],  # convolutional layers
        'POOL': [10, 5],  # pooling settings; TODO confirm meaning in fn.CNNmodel
        'DENSE': [2304, 1152, 1152, 500, 500, 200],  # dense layers
        'OUT': 200,  # output dimensions
    },
)