Commit 4c5459c
Author: Winter Deng

Classify flags into 'general', 'linear', and 'nn' categories, and reorder some flags in main.py

1 parent c033a52 · commit 4c5459c

File tree

3 files changed: +242 −68 lines

docs/cli/classifier.py

Lines changed: 137 additions & 0 deletions
@@ -0,0 +1,137 @@
import os
import sys
import glob
import re
from pathlib import Path
from collections import defaultdict

current_dir = os.path.dirname(os.path.abspath(__file__))
lib_path = os.path.abspath(os.path.join(current_dir, "..", ".."))
sys.path.insert(0, lib_path)


def classify_file_category(path):
    """Map a source file to the 'linear', 'nn', or 'general' category by its path."""
    relative_path = Path(path).relative_to(lib_path)
    return_path = relative_path.as_posix()
    filename = Path(*relative_path.parts[1:]).as_posix() if len(relative_path.parts) > 1 else return_path

    if filename.startswith("linear"):
        category = "linear"
    elif filename.startswith("torch") or filename.startswith("nn"):
        category = "nn"
    else:
        category = "general"
    return category, return_path


def fetch_option_flags(flags):
    # flags = genflags.parser.flags
    flag_list = []

    for flag in flags:
        flag_list.append(
            {
                "name": flag["name"].replace("\\", ""),
                "instruction": flag["name"].split("-")[-1],
                "description": flag["description"]
            }
        )

    return flag_list


def fetch_all_files():
    main_files = [
        os.path.join(lib_path, "linear_trainer.py"),
        os.path.join(lib_path, "torch_trainer.py")
    ]
    lib_files = glob.glob(os.path.join(lib_path, "libmultilabel/**/*.py"), recursive=True)
    file_set = set(map(os.path.abspath, main_files + lib_files))
    return file_set


def find_config_usages_in_file(file_path, allowed_keys):
    # Match attribute accesses such as `config.seed` and capture the key name.
    pattern = re.compile(r'\bconfig\.([a-zA-Z_][a-zA-Z0-9_]*)')
    detailed_results = {}
    try:
        with open(file_path, "r", encoding="utf-8") as f:
            lines = f.readlines()
    except (IOError, UnicodeDecodeError):
        return {}  # empty mapping, consistent with the normal return type

    category, path = classify_file_category(file_path)

    for i, line in enumerate(lines, start=1):
        matches = pattern.findall(line)
        for key in matches:
            if key in allowed_keys:
                if key not in detailed_results:
                    detailed_results[key] = {"file": path, "lines": []}
                detailed_results[key]["lines"].append(str(i))

    return detailed_results


def move_duplicates_together(data, keep):
    # Keys that appear in more than one category are moved into `keep`.
    all_keys = list(data.keys())
    duplicates = set()

    for i, key1 in enumerate(all_keys):
        for key2 in all_keys[i+1:]:
            duplicates |= data[key1] & data[key2]

    data[keep] |= duplicates

    for key in all_keys:
        if key != keep:
            data[key] -= duplicates

    return data


def classify(raw_flags):
    category_set = {"general": set(), "linear": set(), "nn": set()}
    flags = fetch_option_flags(raw_flags)
    allowed_keys = set(flag["instruction"] for flag in flags)
    file_set = fetch_all_files()
    usage_map = defaultdict(set)
    collected = {}

    for file_path in file_set:
        detailed_results = find_config_usages_in_file(file_path, allowed_keys)
        if detailed_results:
            usage_map[file_path] = set(detailed_results.keys())
            for k, v in detailed_results.items():
                if k not in collected:
                    collected[k] = []
                collected[k].append(v)

    for path, keys in usage_map.items():
        category, path = classify_file_category(path)
        category_set[category] = category_set[category].union(keys)

    # Flags used by both trainers belong in the shared "general" category.
    category_set = move_duplicates_together(category_set, "general")

    for flag in flags:
        for k, v in category_set.items():
            for i in v:
                if flag["instruction"] == i:
                    flag["category"] = k
        if "category" not in flag:
            flag["category"] = "general"

    result = {}
    for flag in flags:
        if flag["category"] not in result:
            result[flag["category"]] = []
        result[flag["category"]].append({"name": flag["name"].replace("--", r"\-\-"), "description": flag["description"]})

    result["details"] = []
    for k, v in collected.items():
        result["details"].append({"name": k, "file": v[0]["file"], "location": ", ".join(v[0]["lines"])})
        if len(v) > 1:
            for i in v[1:]:
                result["details"].append({"name": "", "file": i["file"], "location": ", ".join(i["lines"])})

    return result
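
For reference, a minimal usage sketch of the new module (not part of this commit). It assumes a LibMultiLabel checkout so fetch_all_files() can find linear_trainer.py, torch_trainer.py, and libmultilabel/; the sample flag dicts mirror the shape FakeParser records in genflags.py:

# Hypothetical driver, run from docs/cli/ inside a LibMultiLabel checkout.
from classifier import classify

raw_flags = [
    {"name": "--linear", "description": "Train linear model"},
    {"name": "--seed", "description": "Random seed (default: %(default)s)"},
    {"name": "--epochs", "description": "The number of epochs to train (default: %(default)s)"},
]
classified = classify(raw_flags)
for category in ("general", "linear", "nn"):
    # Names come back with "--" escaped as "\-\-" for the reST docs.
    print(category, "->", [flag["name"] for flag in classified.get(category, [])])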

docs/cli/genflags.py

Lines changed: 37 additions & 13 deletions
@@ -2,8 +2,11 @@
 import os
 
 sys.path.insert(1, os.path.join(sys.path[0], "..", ".."))
+
 import main
 
+from classifier import classify
+
 
 class FakeParser(dict):
     def __init__(self):
@@ -29,21 +32,42 @@ def add_argument(
 parser.add_argument("-c", "--config", help="Path to configuration file")
 main.add_all_arguments(parser)
 
+classified = classify(parser.flags)
 
-def width(key):
-    return max(map(lambda f: len(f[key]), parser.flags))
+def width_title(key, title):
+    return max(map(lambda f: len(f[key]), classified[title]))
 
+def print_table(title, flags, intro):
+    print()
+    print(intro)
+    print()
 
-wn = width("name")
-wd = width("description")
+    wn = width_title("name", title)
+    wd = width_title("description", title)
 
-print(
-    """..
-    Do not modify this file. This file is generated by genflags.py.\n"""
+    print("=" * wn, "=" * wd)
+    print("Name".ljust(wn), "Description".ljust(wd))
+    print("=" * wn, "=" * wd)
+    for flag in flags:
+        print(flag["name"].ljust(wn), flag["description"].ljust(wd))
+    print("=" * wn, "=" * wd)
+    print()
+
+print_table(
+    "general",
+    classified["general"],
+    intro="**General options**:\n\
+Common configurations shared across both linear and neural network trainers."
+)
+print_table(
+    "linear",
+    classified["linear"],
+    intro="**Linear options**:\n\
+Configurations specific to linear trainer."
 )
-print("=" * wn, "=" * wd)
-print("Name".ljust(wn), "Description".ljust(wd))
-print("=" * wn, "=" * wd)
-for flag in parser.flags:
-    print(flag["name"].ljust(wn), flag["description"].ljust(wd))
-print("=" * wn, "=" * wd)
+print_table(
+    "nn",
+    classified["nn"],
+    intro="**Neural network options**:\n\
+Configurations specific to torch (neural networks) trainer."
+)
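
For illustration, one generated table would look roughly like this (hypothetical flag and widths; classify escapes "--" as "\-\-" so the names survive reStructuredText):

**General options**:
Common configurations shared across both linear and neural network trainers.

======== ==================================
Name     Description
======== ==================================
\-\-seed Random seed (default: %(default)s)
======== ==================================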

main.py

Lines changed: 68 additions & 55 deletions
@@ -11,21 +11,50 @@
 
 
 def add_all_arguments(parser):
-    # path / directory
+
     parser.add_argument(
-        "--result_dir", default="./runs", help="The directory to save checkpoints and logs (default: %(default)s)"
+        "-h",
+        "--help",
+        action="help",
+        help="Quickstart: https://www.csie.ntu.edu.tw/~cjlin/libmultilabel/cli/quickstart.html",
     )
 
+    parser.add_argument("--seed", type=int, help="Random seed (default: %(default)s)")
+
+    # choose model (linear / nn)
+    parser.add_argument("--linear", action="store_true", help="Train linear model")
+
+    # others
+    parser.add_argument("--cpu", action="store_true", help="Disable CUDA")
+    parser.add_argument("--silent", action="store_true", help="Enable silent mode")
+    parser.add_argument(
+        "--data_workers", type=int, default=4, help="Use multiple CPU cores for data pre-processing (default: %(default)s)"
+    )
+    parser.add_argument(
+        "--embed_cache_dir",
+        type=str,
+        help="For parameter search only: path to a directory for storing embeddings for multiple runs. (default: %(default)s)",
+    )
+    parser.add_argument(
+        "--eval", action="store_true", help="Only run evaluation on the test set (default: %(default)s)"
+    )
+    parser.add_argument("--checkpoint_path", help="The checkpoint to warm-up with (default: %(default)s)")
+
     # data
-    parser.add_argument("--data_name", default="unnamed_data", help="Dataset name (default: %(default)s)")
+    parser.add_argument(
+        "--data_name",
+        default="unnamed_data",
+        help="Dataset name for generating the output directory (default: %(default)s)",
+    )
     parser.add_argument("--training_file", help="Path to training data (default: %(default)s)")
     parser.add_argument("--val_file", help="Path to validation data (default: %(default)s)")
-    parser.add_argument("--test_file", help="Path to test data (default: %(default)s")
+    parser.add_argument("--test_file", help="Path to test data (default: %(default)s)")
+    parser.add_argument("--label_file", type=str, help="Path to a file holding all labels (default: %(default)s)")
     parser.add_argument(
         "--val_size",
         type=float,
         default=0.2,
-        help="Training-validation split: a ratio in [0, 1] or an integer for the size of the validation set (default: %(default)s).",
+        help="Training-validation split: a ratio in [0, 1] or an integer for the size of the validation set (default: %(default)s)",
     )
     parser.add_argument(
         "--min_vocab_freq",
@@ -67,8 +96,24 @@ def add_all_arguments(parser):
         help="Whether to add the special tokens for inputs of the transformer-based language model. (default: %(default)s)",
     )
 
+    # model
+    parser.add_argument("--model_name", default="unnamed_model", help="Model to be used (default: %(default)s)")
+    parser.add_argument(
+        "--init_weight", default="kaiming_uniform", help="Weight initialization to be used (default: %(default)s)"
+    )
+    parser.add_argument(
+        "--loss_function", default="binary_cross_entropy_with_logits", help="Loss function (default: %(default)s)"
+    )
+
+    # pretrained vocab / embeddings
+    parser.add_argument("--vocab_file", type=str, help="Path to a file holding vocabularies (default: %(default)s)")
+    parser.add_argument(
+        "--embed_file",
+        type=str,
+        help="Path to a file holding pre-trained embeddings or the name of the pretrained GloVe embedding (default: %(default)s)",
+    )
+
     # train
-    parser.add_argument("--seed", type=int, help="Random seed (default: %(default)s)")
     parser.add_argument(
         "--epochs", type=int, default=10000, help="The number of epochs to train (default: %(default)s)"
     )
@@ -109,15 +154,6 @@ def add_all_arguments(parser):
         help="Whether the embeddings of each word is normalized to a unit vector (default: %(default)s)",
     )
 
-    # model
-    parser.add_argument("--model_name", default="unnamed_model", help="Model to be used (default: %(default)s)")
-    parser.add_argument(
-        "--init_weight", default="kaiming_uniform", help="Weight initialization to be used (default: %(default)s)"
-    )
-    parser.add_argument(
-        "--loss_function", default="binary_cross_entropy_with_logits", help="Loss function (default: %(default)s)"
-    )
-
     # eval
     parser.add_argument(
         "--eval_batch_size", type=int, default=256, help="Size of evaluating batches (default: %(default)s)"
@@ -138,28 +174,6 @@ def add_all_arguments(parser):
         "--val_metric", default="P@1", help="The metric to select the best model for testing (default: %(default)s)"
     )
 
-    # pretrained vocab / embeddings
-    parser.add_argument("--vocab_file", type=str, help="Path to a file holding vocabuaries (default: %(default)s)")
-    parser.add_argument(
-        "--embed_file", type=str, help="Path to a file holding pre-trained embeddings or the name of the pretrained GloVe embedding (default: %(default)s)"
-    )
-    parser.add_argument("--label_file", type=str, help="Path to a file holding all labels (default: %(default)s)")
-
-    # log
-    parser.add_argument(
-        "--save_k_predictions",
-        type=int,
-        nargs="?",
-        const=100,
-        default=0,
-        help="Save top k predictions on test set. k=%(const)s if not specified. (default: %(default)s)",
-    )
-    parser.add_argument(
-        "--predict_out_path",
-        default="./predictions.txt",
-        help="Path to the output file holding label results (default: %(default)s)",
-    )
-
     # auto-test
     parser.add_argument(
         "--limit_train_batches",
@@ -180,24 +194,27 @@ def add_all_arguments(parser):
         help="Percentage of test dataset to use for auto-testing (default: %(default)s)",
    )
 
-    # others
-    parser.add_argument("--cpu", action="store_true", help="Disable CUDA")
-    parser.add_argument("--silent", action="store_true", help="Enable silent mode")
+    # log
     parser.add_argument(
-        "--data_workers", type=int, default=4, help="Use multi-cpu core for data pre-processing (default: %(default)s)"
+        "--save_k_predictions",
+        type=int,
+        nargs="?",
+        const=100,
+        default=0,
+        help="Save top k predictions on test set. k=%(const)s if not specified. (default: %(default)s)",
     )
     parser.add_argument(
-        "--embed_cache_dir",
-        type=str,
-        help="For parameter search only: path to a directory for storing embeddings for multiple runs. (default: %(default)s)",
+        "--predict_out_path",
+        default="./predictions.txt",
+        help="Path to the output file holding label results (default: %(default)s)",
     )
+
+    # path / directory
     parser.add_argument(
-        "--eval", action="store_true", help="Only run evaluation on the test set (default: %(default)s)"
+        "--result_dir", default="./runs", help="The directory to save checkpoints and logs (default: %(default)s)"
     )
-    parser.add_argument("--checkpoint_path", help="The checkpoint to warm-up with (default: %(default)s)")
 
     # linear options
-    parser.add_argument("--linear", action="store_true", help="Train linear model")
     parser.add_argument(
         "--data_format",
         type=str,
@@ -224,7 +241,10 @@ def add_all_arguments(parser):
         "--tree_max_depth", type=int, default=10, help="Maximum depth of the tree (default: %(default)s)"
     )
     parser.add_argument(
-        "--tree_ensemble_models", type=int, default=1, help="Number of models in the tree ensemble (default: %(default)s)"
+        "--tree_ensemble_models",
+        type=int,
+        default=1,
+        help="Number of models in the tree ensemble (default: %(default)s)",
     )
     parser.add_argument(
         "--beam_width",
@@ -239,13 +259,6 @@ def add_all_arguments(parser):
         default=8,
         help="the maximal number of labels inside a cluster (default: %(default)s)",
     )
-    parser.add_argument(
-        "-h",
-        "--help",
-        action="help",
-        help="If you are trying to specify network config such as dropout or activation or config of the learning rate scheduler, use a yaml file instead. "
-        "See example configs in example_config",
-    )
 
 
 def get_config():
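
One detail the diff relies on but does not show: argparse raises a conflict if "-h"/"--help" is registered while the parser's built-in help is active, so the manual registration at the top of add_all_arguments implies the parser is created with add_help=False. A minimal standalone sketch of that pattern (the parser construction here is an assumption, not code from this commit):

import argparse

# Assumed construction: built-in help suppressed so "-h"/"--help" can be
# registered manually as the first flag, as in add_all_arguments above.
parser = argparse.ArgumentParser(add_help=False)
parser.add_argument(
    "-h",
    "--help",
    action="help",
    help="Quickstart: https://www.csie.ntu.edu.tw/~cjlin/libmultilabel/cli/quickstart.html",
)
parser.add_argument("--seed", type=int, help="Random seed (default: %(default)s)")

args = parser.parse_args(["--seed", "42"])
print(args.seed)  # 42; "-h, --help" now appears first in the usage message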
