Skip to content

Commit c2b1e49

Browse files
Added Option for Test Set
1 parent 9bee743 commit c2b1e49

File tree

1 file changed

+14
-5
lines changed

1 file changed

+14
-5
lines changed

misc/grid_search_gcnn.py

+14-5
Original file line numberDiff line numberDiff line change
@@ -29,6 +29,10 @@
2929
parser.add_argument("--vocab_size", type=int, default=None,
3030
help="Vocabulary size (default: None [see data.py])")
3131

32+
parser.add_argument("--test", action="store_false", dest="validation",
33+
help="Include this flag if models should be tuned on the test set instead.")
34+
parser.set_defaults(validation=True)
35+
3236
args = parser.parse_args()
3337

3438

@@ -119,15 +123,20 @@ def run_experiment(x_train, y_train, x_valid, y_valid, embeddings, _num_edges, _
119123
# ==================================================
120124

121125
train, test = data.load_dataset(args.dataset, out="tfidf", vocab_size=10000)
122-
del test
123126

124127
x_train = train.data.astype(np.float32)
125128
y_train = train.labels
126129

127-
# Split training set & validation set
128-
validation_index = -1 * int(0.1 * float(len(y_train)))
129-
x_train, x_valid = x_train[:validation_index], x_train[validation_index:]
130-
y_train, y_valid = y_train[:validation_index], y_train[validation_index:]
130+
if args.validation:
131+
del test # don't need this anymore
132+
133+
# Split training set & validation set
134+
validation_index = -1 * int(0.1 * float(len(y_train)))
135+
x_train, x_valid = x_train[:validation_index], x_train[validation_index:]
136+
y_train, y_valid = y_train[:validation_index], y_train[validation_index:]
137+
else:
138+
x_valid = test.data.astype(np.float32)
139+
y_valid = test.labels
131140

132141
# Construct reverse lookup vocabulary
133142
reverse_vocab = {w: i for i, w in enumerate(train.vocab)}

0 commit comments

Comments
 (0)