Skip to content

Commit 3a0e32c

Browse files
authored
Merge pull request #26 from Abies-0/master
Classify flags into general, linear, nn categories
2 parents c033a52 + 3d725f5 commit 3a0e32c

File tree

4 files changed

+211
-68
lines changed

4 files changed

+211
-68
lines changed

docs/cli/classifier.py

Lines changed: 102 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,102 @@
1+
import os
2+
import sys
3+
import glob
4+
import re
5+
from pathlib import Path
6+
from collections import defaultdict
7+
8+
lib_path = Path.cwd().parent
9+
sys.path.insert(0, str(lib_path))
10+
11+
12+
def classify_file_category(path):
    """Return the flag category ("linear", "nn" or "general") for a source file.

    The category is derived from the file's path relative to ``lib_path``:
    a tail starting with "linear" maps to "linear", one starting with
    "torch" or "nn" maps to "nn", and anything else is "general".
    """
    rel = Path(path).relative_to(lib_path)
    # Drop the leading path component (presumably the package directory —
    # confirm against fetch_all_files); files sitting directly under
    # lib_path fall back to their full relative name.
    tail = "/".join(rel.parts[1:]) or rel.as_posix()

    for prefixes, category in ((("linear",), "linear"), (("torch", "nn"), "nn")):
        if tail.startswith(prefixes):
            return category
    return "general"
21+
22+
23+
def fetch_option_flags(flags):
    """Normalize raw CLI flag records for classification.

    Parameters
    ----------
    flags : iterable of dict
        Raw records, each with a "name" key (possibly containing escaped
        dashes such as ``\\-\\-seed``) and a "description" key.

    Returns
    -------
    list of dict
        One record per input flag with:
        - "name": the flag name with backslash escapes stripped,
        - "instruction": the bare keyword after the last "-" (the token
          that appears as ``config.<keyword>`` in the scanned sources),
        - "description": the original description, unchanged.
    """
    # A comprehension replaces the original append loop; the produced
    # records are identical.
    return [
        {
            "name": flag["name"].replace("\\", ""),
            "instruction": flag["name"].split("-")[-1],
            "description": flag["description"],
        }
        for flag in flags
    ]
36+
37+
38+
def fetch_all_files():
    """Collect the absolute paths of all Python sources to scan for config usage.

    Covers the three top-level entry points (main.py and the two trainers)
    plus every ``.py`` file under the ``libmultilabel`` package, as a set.
    """
    entry_points = [
        os.path.join(lib_path, name)
        for name in ("main.py", "linear_trainer.py", "torch_trainer.py")
    ]
    package_files = glob.glob(
        os.path.join(lib_path, "libmultilabel/**/*.py"), recursive=True
    )
    # Absolute paths de-duplicate entries that glob and the explicit list
    # might both produce.
    return {os.path.abspath(p) for p in entry_points + package_files}
47+
48+
49+
def find_config_usages_in_file(file_path, allowed_keys, category_set):
    """Record which ``config.<key>`` attributes a file uses, by category.

    Scans ``file_path`` for ``config.<identifier>`` references, keeps only
    those in ``allowed_keys``, and adds them in place to the set in
    ``category_set`` matching the file's category.
    """
    config_attr = re.compile(r"\bconfig\.([a-zA-Z_][a-zA-Z0-9_]*)")

    with open(file_path, "r", encoding="utf-8") as source:
        lines = source.readlines()

    # NOTE(review): for main.py only the text from ``def main(`` onward is
    # scanned — presumably to skip the argument-definition section; confirm.
    if file_path.endswith("main.py"):
        for idx, line in enumerate(lines):
            if line.startswith("def main("):
                lines = lines[idx:]
                break

    found = set(config_attr.findall(" ".join(lines)))
    category_set[classify_file_category(file_path)].update(found & allowed_keys)
66+
67+
68+
def move_duplicates_together(data):
    """Move every key used by more than one category into "general", in place.

    ``data`` maps "general"/"linear"/"nn" to sets of keys; after the call
    the three sets are pairwise disjoint.
    """
    general, linear, nn = data["general"], data["linear"], data["nn"]
    # A key is shared when it appears in at least two of the three sets.
    shared = (general & linear) | (general & nn) | (linear & nn)
    # In-place set operators mutate the very sets stored in ``data``.
    general |= shared
    linear -= shared
    nn -= shared
73+
74+
75+
def classify(raw_flags):
    """Split raw CLI flag records into "general"/"linear"/"nn" buckets.

    A flag goes to the category of the source files that reference
    ``config.<flag>``; flags referenced by several categories, or by none,
    land in "general".

    Returns a ``defaultdict(list)`` mapping category name to a list of
    ``{"name", "description"}`` records with ``--`` escaped as ``\\-\\-``.
    """
    category_set = {"general": set(), "linear": set(), "nn": set()}

    allowed_keys = {flag["instruction"] for flag in fetch_option_flags(raw_flags)}
    for file_path in fetch_all_files():
        find_config_usages_in_file(file_path, allowed_keys, category_set)
    move_duplicates_together(category_set)

    def category_of(instruction):
        # First category whose usage set holds the key; "general" otherwise.
        for name, keys in category_set.items():
            if instruction in keys:
                return name
        return "general"

    result = defaultdict(list)
    for flag in raw_flags:
        instruction = flag["name"].replace("\\", "").split("-")[-1]
        escaped_name = flag["name"].replace("--", r"\-\-")
        result[category_of(instruction)].append(
            {"name": escaped_name, "description": flag["description"]}
        )
    return result

docs/cli/genflags.py

Lines changed: 40 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -2,8 +2,11 @@
22
import os
33

44
sys.path.insert(1, os.path.join(sys.path[0], "..", ".."))
5+
56
import main
67

8+
from classifier import classify
9+
710

811
class FakeParser(dict):
912
def __init__(self):
@@ -29,21 +32,45 @@ def add_argument(
2932
parser.add_argument("-c", "--config", help="Path to configuration file")
3033
main.add_all_arguments(parser)
3134

35+
classified = classify(parser.flags)
36+
37+
38+
def width_title(key, title):
    """Return the widest ``key`` field among the flags of category ``title``.

    Used to size the RST table columns. ``default=0`` keeps an empty
    category from raising ValueError (bare ``max`` on an empty iterable
    would crash doc generation).
    """
    return max((len(flag[key]) for flag in classified[title]), default=0)
3240

33-
def width(key):
34-
return max(map(lambda f: len(f[key]), parser.flags))
3541

42+
def print_table(title, flags, intro):
    """Print an RST simple table of the flags in category ``title``.

    ``intro`` is printed first (surrounded by blank lines), then a
    Name/Description table whose columns are padded to the widest entry.
    """
    print()
    print(intro)
    print()

    name_w = width_title("name", title)
    desc_w = width_title("description", title)
    # The '====' separator row used above, between, and below the rows.
    rule = " ".join(("=" * name_w, "=" * desc_w))

    print(rule)
    print("Name".ljust(name_w), "Description".ljust(desc_w))
    print(rule)
    for flag in flags:
        print(flag["name"].ljust(name_w), flag["description"].ljust(desc_w))
    print(rule)
    print()
57+
58+
59+
# Emit one table per flag category, each with a short heading paragraph.
for _title, _intro in (
    (
        "general",
        "**General options**:\n"
        "Common configurations shared across both linear and neural network trainers.",
    ),
    (
        "linear",
        "**Linear options**:\n"
        "Configurations specific to linear trainer.",
    ),
    (
        "nn",
        "**Neural network options**:\n"
        "Configurations specific to torch (neural networks) trainer.",
    ),
):
    print_table(_title, classified[_title], intro=_intro)
44-
print("=" * wn, "=" * wd)
45-
print("Name".ljust(wn), "Description".ljust(wd))
46-
print("=" * wn, "=" * wd)
47-
for flag in parser.flags:
48-
print(flag["name"].ljust(wn), flag["description"].ljust(wd))
49-
print("=" * wn, "=" * wd)

docs/conf.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -49,6 +49,7 @@
4949
"examples_dirs": "./examples", # path to your example scripts
5050
"gallery_dirs": "auto_examples", # path to where to save gallery generated output
5151
"plot_gallery": False,
52+
"write_computation_times": False,
5253
}
5354

5455
# bibtex files

main.py

Lines changed: 68 additions & 55 deletions
Original file line numberDiff line numberDiff line change
@@ -11,21 +11,50 @@
1111

1212

1313
def add_all_arguments(parser):
14-
# path / directory
14+
1515
parser.add_argument(
16-
"--result_dir", default="./runs", help="The directory to save checkpoints and logs (default: %(default)s)"
16+
"-h",
17+
"--help",
18+
action="help",
19+
help="Quickstart: https://www.csie.ntu.edu.tw/~cjlin/libmultilabel/cli/quickstart.html",
1720
)
1821

22+
parser.add_argument("--seed", type=int, help="Random seed (default: %(default)s)")
23+
24+
# choose model (linear / nn)
25+
parser.add_argument("--linear", action="store_true", help="Train linear model")
26+
27+
# others
28+
parser.add_argument("--cpu", action="store_true", help="Disable CUDA")
29+
parser.add_argument("--silent", action="store_true", help="Enable silent mode")
30+
parser.add_argument(
31+
"--data_workers", type=int, default=4, help="Use multi-cpu core for data pre-processing (default: %(default)s)"
32+
)
33+
parser.add_argument(
34+
"--embed_cache_dir",
35+
type=str,
36+
help="For parameter search only: path to a directory for storing embeddings for multiple runs. (default: %(default)s)",
37+
)
38+
parser.add_argument(
39+
"--eval", action="store_true", help="Only run evaluation on the test set (default: %(default)s)"
40+
)
41+
parser.add_argument("--checkpoint_path", help="The checkpoint to warm-up with (default: %(default)s)")
42+
1943
# data
20-
parser.add_argument("--data_name", default="unnamed_data", help="Dataset name (default: %(default)s)")
44+
parser.add_argument(
45+
"--data_name",
46+
default="unnamed_data",
47+
help="Dataset name for generating the output directory (default: %(default)s)",
48+
)
2149
parser.add_argument("--training_file", help="Path to training data (default: %(default)s)")
2250
parser.add_argument("--val_file", help="Path to validation data (default: %(default)s)")
23-
parser.add_argument("--test_file", help="Path to test data (default: %(default)s")
51+
parser.add_argument("--test_file", help="Path to test data (default: %(default)s)")
52+
parser.add_argument("--label_file", type=str, help="Path to a file holding all labels (default: %(default)s)")
2453
parser.add_argument(
2554
"--val_size",
2655
type=float,
2756
default=0.2,
28-
help="Training-validation split: a ratio in [0, 1] or an integer for the size of the validation set (default: %(default)s).",
57+
help="Training-validation split: a ratio in [0, 1] or an integer for the size of the validation set (default: %(default)s)",
2958
)
3059
parser.add_argument(
3160
"--min_vocab_freq",
@@ -67,8 +96,24 @@ def add_all_arguments(parser):
6796
help="Whether to add the special tokens for inputs of the transformer-based language model. (default: %(default)s)",
6897
)
6998

99+
# model
100+
parser.add_argument("--model_name", default="unnamed_model", help="Model to be used (default: %(default)s)")
101+
parser.add_argument(
102+
"--init_weight", default="kaiming_uniform", help="Weight initialization to be used (default: %(default)s)"
103+
)
104+
parser.add_argument(
105+
"--loss_function", default="binary_cross_entropy_with_logits", help="Loss function (default: %(default)s)"
106+
)
107+
108+
# pretrained vocab / embeddings
109+
parser.add_argument("--vocab_file", type=str, help="Path to a file holding vocabuaries (default: %(default)s)")
110+
parser.add_argument(
111+
"--embed_file",
112+
type=str,
113+
help="Path to a file holding pre-trained embeddings or the name of the pretrained GloVe embedding (default: %(default)s)",
114+
)
115+
70116
# train
71-
parser.add_argument("--seed", type=int, help="Random seed (default: %(default)s)")
72117
parser.add_argument(
73118
"--epochs", type=int, default=10000, help="The number of epochs to train (default: %(default)s)"
74119
)
@@ -109,15 +154,6 @@ def add_all_arguments(parser):
109154
help="Whether the embeddings of each word is normalized to a unit vector (default: %(default)s)",
110155
)
111156

112-
# model
113-
parser.add_argument("--model_name", default="unnamed_model", help="Model to be used (default: %(default)s)")
114-
parser.add_argument(
115-
"--init_weight", default="kaiming_uniform", help="Weight initialization to be used (default: %(default)s)"
116-
)
117-
parser.add_argument(
118-
"--loss_function", default="binary_cross_entropy_with_logits", help="Loss function (default: %(default)s)"
119-
)
120-
121157
# eval
122158
parser.add_argument(
123159
"--eval_batch_size", type=int, default=256, help="Size of evaluating batches (default: %(default)s)"
@@ -138,28 +174,6 @@ def add_all_arguments(parser):
138174
"--val_metric", default="P@1", help="The metric to select the best model for testing (default: %(default)s)"
139175
)
140176

141-
# pretrained vocab / embeddings
142-
parser.add_argument("--vocab_file", type=str, help="Path to a file holding vocabuaries (default: %(default)s)")
143-
parser.add_argument(
144-
"--embed_file", type=str, help="Path to a file holding pre-trained embeddings or the name of the pretrained GloVe embedding (default: %(default)s)"
145-
)
146-
parser.add_argument("--label_file", type=str, help="Path to a file holding all labels (default: %(default)s)")
147-
148-
# log
149-
parser.add_argument(
150-
"--save_k_predictions",
151-
type=int,
152-
nargs="?",
153-
const=100,
154-
default=0,
155-
help="Save top k predictions on test set. k=%(const)s if not specified. (default: %(default)s)",
156-
)
157-
parser.add_argument(
158-
"--predict_out_path",
159-
default="./predictions.txt",
160-
help="Path to the output file holding label results (default: %(default)s)",
161-
)
162-
163177
# auto-test
164178
parser.add_argument(
165179
"--limit_train_batches",
@@ -180,24 +194,27 @@ def add_all_arguments(parser):
180194
help="Percentage of test dataset to use for auto-testing (default: %(default)s)",
181195
)
182196

183-
# others
184-
parser.add_argument("--cpu", action="store_true", help="Disable CUDA")
185-
parser.add_argument("--silent", action="store_true", help="Enable silent mode")
197+
# log
186198
parser.add_argument(
187-
"--data_workers", type=int, default=4, help="Use multi-cpu core for data pre-processing (default: %(default)s)"
199+
"--save_k_predictions",
200+
type=int,
201+
nargs="?",
202+
const=100,
203+
default=0,
204+
help="Save top k predictions on test set. k=%(const)s if not specified. (default: %(default)s)",
188205
)
189206
parser.add_argument(
190-
"--embed_cache_dir",
191-
type=str,
192-
help="For parameter search only: path to a directory for storing embeddings for multiple runs. (default: %(default)s)",
207+
"--predict_out_path",
208+
default="./predictions.txt",
209+
help="Path to the output file holding label results (default: %(default)s)",
193210
)
211+
212+
# path / directory
194213
parser.add_argument(
195-
"--eval", action="store_true", help="Only run evaluation on the test set (default: %(default)s)"
214+
"--result_dir", default="./runs", help="The directory to save checkpoints and logs (default: %(default)s)"
196215
)
197-
parser.add_argument("--checkpoint_path", help="The checkpoint to warm-up with (default: %(default)s)")
198216

199217
# linear options
200-
parser.add_argument("--linear", action="store_true", help="Train linear model")
201218
parser.add_argument(
202219
"--data_format",
203220
type=str,
@@ -224,7 +241,10 @@ def add_all_arguments(parser):
224241
"--tree_max_depth", type=int, default=10, help="Maximum depth of the tree (default: %(default)s)"
225242
)
226243
parser.add_argument(
227-
"--tree_ensemble_models", type=int, default=1, help="Number of models in the tree ensemble (default: %(default)s)"
244+
"--tree_ensemble_models",
245+
type=int,
246+
default=1,
247+
help="Number of models in the tree ensemble (default: %(default)s)",
228248
)
229249
parser.add_argument(
230250
"--beam_width",
@@ -239,13 +259,6 @@ def add_all_arguments(parser):
239259
default=8,
240260
help="the maximal number of labels inside a cluster (default: %(default)s)",
241261
)
242-
parser.add_argument(
243-
"-h",
244-
"--help",
245-
action="help",
246-
help="If you are trying to specify network config such as dropout or activation or config of the learning rate scheduler, use a yaml file instead. "
247-
"See example configs in example_config",
248-
)
249262

250263

251264
def get_config():

0 commit comments

Comments
 (0)