Skip to content

Commit 9bee743

Browse files
Added Option to Tune on Test Set & Added Comments
1 parent 19f2aaa commit 9bee743

File tree

1 file changed

+29
-9
lines changed

1 file changed

+29
-9
lines changed

misc/tune_linearsvc.py

+29-9
Original file line numberDiff line numberDiff line change
@@ -29,6 +29,10 @@
2929
parser.add_argument("--out", type=str, default="tfidf", choices=["tfidf", "count"],
3030
help="Type of document vectors (default: tfidf)")
3131

32+
parser.add_argument("--test", action="store_false", dest="validation",
33+
help="Include this flag if models should be tuned on the test set instead.")
34+
parser.set_defaults(validation=True)
35+
3236
parser.add_argument("--min", type=float, default=0)
3337
parser.add_argument("--max", type=float, default=5)
3438

@@ -45,10 +49,13 @@
4549
y_train = train.labels
4650
y_test = test.labels
4751

48-
# Split training set & validation set
49-
validation_index = -1 * int(0.1 * float(len(y_train)))
50-
x_train, x_valid = x_train[:validation_index], x_train[validation_index:]
51-
y_train, y_valid = y_train[:validation_index], y_train[validation_index:]
52+
if args.validation:
53+
# Split training set & validation set
54+
validation_index = -1 * int(0.1 * float(len(y_train)))
55+
x_train, x_valid = x_train[:validation_index], x_train[validation_index:]
56+
y_train, y_valid = y_train[:validation_index], y_train[validation_index:]
57+
else:
58+
x_valid, y_valid = [], []
5259

5360
# Print information about the dataset
5461
print("")
@@ -69,29 +76,42 @@
6976
# Training
7077
# ==================================================
7178

72-
acc_dict = {}
79+
# Generate C values to test [min, max, 0.1]
7380
C_arr = [float('%.1f' % i) for i in np.arange(args.min, args.max + 0.1, 0.1)]
81+
82+
# Train & test models with different hyperparameter values
83+
acc_dict = {}
7484
for i in C_arr:
7585
if i <= 0:
7686
continue
7787
svm_clf = LinearSVC(C=i)
7888
svm_clf.fit(x_train, y_train)
79-
predicted = svm_clf.predict(x_valid)
80-
svm_acc = np.mean(predicted == y_valid)
89+
if args.validation:
90+
predicted = svm_clf.predict(x_valid)
91+
svm_acc = np.mean(predicted == y_valid)
92+
else:
93+
predicted = svm_clf.predict(x_test)
94+
svm_acc = np.mean(predicted == y_test)
8195
acc_dict[i] = svm_acc
8296
print("C {:.2f}: {:g}".format(i, svm_acc))
8397

8498
print(acc_dict)
8599
print("")
86100

87-
x_train = vstack((x_train, x_valid))
88-
y_train = np.concatenate((y_train, y_valid), axis=0)
101+
# Concatenate training set & validation set to form original train set
102+
if args.validation:
103+
x_train = vstack((x_train, x_valid))
104+
y_train = np.concatenate((y_train, y_valid), axis=0)
105+
106+
# Get optimized hyperparameter
89107
max_C = max(acc_dict.keys(), key=(lambda key: acc_dict[key]))
90108

109+
# Re-train & test model with chosen hyperparameter
91110
svm_clf = LinearSVC(C=max_C)
92111
svm_clf.fit(x_train, y_train)
93112
predicted = svm_clf.predict(x_test)
94113
svm_acc = np.mean(predicted == y_test)
95114

115+
# Print result of final model
96116
utils.print_result(args.dataset, "linear_svc", svm_acc, data_str, str(int(time.time())),
97117
hyperparams="{{C: {}}}".format(max_C))

0 commit comments

Comments
 (0)