-
Notifications
You must be signed in to change notification settings - Fork 1
/
Copy pathgenerate_eval_metrics.py
153 lines (138 loc) · 7.87 KB
/
generate_eval_metrics.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
import os
from tqdm import tqdm
from utils.geocoding_utils import compute_api_distance
from utils.metric_utils import compute_basic_metrics, bootstrap_f1_error_bars, compute_withheld_leaked
from utils.format_utils import print_table
# args for experiments
from argparse import ArgumentParser
def build_arg_parser():
    """Build the command-line parser for the evaluation experiments.

    Returns:
        ArgumentParser with one boolean flag per experiment
        (``--basic_metrics``, ``--privacy_utility``, ``--geocoding_distance``),
        ``--all`` to run every experiment, ``--recompute_geocoding_results``,
        and ``--agents`` (one or more agent names to restrict evaluation to).
    """
    arg_parser = ArgumentParser()
    arg_parser.add_argument("--basic_metrics", action="store_true",
                            help="Calculate basic metrics")
    arg_parser.add_argument("--privacy_utility", action="store_true",
                            help="Calculate withheld and leaked proportions")
    arg_parser.add_argument("--geocoding_distance", action="store_true",
                            help="Calculate geocoding distance error")
    arg_parser.add_argument("--all", action="store_true", help="Run all experiments")
    arg_parser.add_argument("--recompute_geocoding_results",
                            action="store_true", help="Recompute geocoding results")
    arg_parser.add_argument('--agents', nargs='+', help='List of agents to evaluate')
    return arg_parser


parser = build_arg_parser()
# Parse the real command line only when executed as a script. When this module is
# imported (e.g. to reuse get_agent_results, or by a test runner), parsing sys.argv
# would consume the importer's arguments and argparse would typically exit with
# "unrecognized arguments"; fall back to the defaults instead.
args = parser.parse_args(None if __name__ == "__main__" else [])

# Granularity levels at which agents are evaluated; result file names are expected
# to carry one of these after "_granularity=".
GRANULARITIES = ["country", "city", "neighborhood",
                 "exact_location_name", "exact_gps_coordinates"]
def get_agent_results(results_dir):
    """Index the ``.jsonl`` result files in *results_dir* by agent and granularity.

    Each result file must be named ``<agent>_granularity=<granularity>.jsonl``;
    files with any other extension are ignored.

    Args:
        results_dir: Directory containing the result files.

    Returns:
        dict mapping ``"<agent>_<granularity>"`` to a dict with keys
        ``"granularity"`` (the parsed granularity string) and ``"filename"``
        (``os.path.join(results_dir, <file name>)``). Entries are ordered by
        file name so that reports are reproducible across runs.

    Raises:
        ValueError: if a ``.jsonl`` file lacks the ``_granularity=`` marker.
        OSError: propagated from ``os.listdir`` if *results_dir* cannot be listed.
    """
    marker = "_granularity="
    suffix = ".jsonl"
    agent_results = {}
    # os.listdir returns entries in arbitrary order; sort for deterministic output.
    for filename in sorted(os.listdir(results_dir)):
        if not filename.endswith(suffix):
            continue
        # The first occurrence of the marker separates the agent name from the granularity.
        agent, found, rest = filename.partition(marker)
        if not found:
            raise ValueError(f"Granularity not found in filename: {filename}")
        granularity = rest[:-len(suffix)]
        agent_results[f"{agent}_{granularity}"] = {
            "granularity": granularity,
            "filename": os.path.join(results_dir, filename),
        }
    return agent_results
def format_distance_row(model, distance_thresholds):
    """Build one geocoding-distance table row for *model*.

    Args:
        model: Result key of the agent (``"<agent>_<granularity>"``).
        distance_thresholds: Mapping from ``compute_api_distance``. Its ``'all'``
            entry is used as the total number of guesses; every other entry maps
            a threshold to the number of guesses counted within it.

    Returns:
        dict with a ``"model"`` key plus one ``"within <threshold> km"`` key per
        non-``'all'`` threshold, valued as a percentage string such as
        ``"25.0 %"``. If the total is zero the percentages are ``"n/a"``
        instead of raising ZeroDivisionError.
    """
    total_guesses = distance_thresholds['all']
    row = {"model": model}
    for threshold, num_guesses in distance_thresholds.items():
        if threshold == 'all':
            continue
        if total_guesses:
            row[f"within {threshold} km"] = f"{round(num_guesses / total_guesses * 100, 1)} %"
        else:
            row[f"within {threshold} km"] = "n/a"
    return row


if __name__ == "__main__":
    if not (args.basic_metrics or args.privacy_utility or args.geocoding_distance or args.all):
        print('No experiment selected; pass --basic_metrics, --privacy_utility, '
              '--geocoding_distance or --all.')

    # Gather results from all models
    baselines_results = get_agent_results("moderation_decisions_baselines")
    base_model_results = get_agent_results("moderation_decisions_prompted")
    finetuned_model_results = get_agent_results("moderation_decisions_finetuned")
    # Combine all results; on a duplicate key the later directory wins.
    all_model_results = {**baselines_results, **base_model_results, **finetuned_model_results}

    # Result keys are "<agent>_<granularity>"; --agents restricts evaluation to the
    # listed agents at every granularity. A set makes each membership test O(1).
    if args.agents:
        formatted_allowed_agent_names = {
            f"{agent}_{granularity}" for agent in args.agents for granularity in GRANULARITIES}
    else:
        formatted_allowed_agent_names = set(all_model_results)
    # Filter once up front so every experiment (and its progress bar) sees the same models.
    selected_model_results = {
        model: model_results_dict
        for model, model_results_dict in all_model_results.items()
        if model in formatted_allowed_agent_names}

    # Run experiments:
    if args.basic_metrics or args.all:
        # Experiment #1a: Basic Metrics
        granularity_results_basic = {granularity: [] for granularity in GRANULARITIES}
        print('Calculating basic metrics...')
        for model, model_results_dict in tqdm(selected_model_results.items()):
            granularity, filename = model_results_dict["granularity"], model_results_dict["filename"]
            recall, precision, f1 = compute_basic_metrics(
                granularity=granularity, answers_file=filename)
            granularity_results_basic[granularity].append(
                {"model": model, "recall": recall, "precision": precision, "f1": f1})

        # Experiment 1b: Error Bars with Bootstrap Method
        print('Calculating errors using bootstrap method...')
        # Rows collected above are already filtered, so the bar total equals the
        # number of bootstrap runs; the context manager closes the bar on errors too.
        total_runs = sum(len(results) for results in granularity_results_basic.values())
        with tqdm(total=total_runs) as bar:
            for granularity, results in granularity_results_basic.items():
                for result in results:
                    filename = selected_model_results[result["model"]]["filename"]
                    _, stderr = bootstrap_f1_error_bars(
                        granularity=granularity, answers_file=filename)
                    result["f1_stderr"] = stderr
                    bar.update(1)

        # print formatted table
        column_display_names_basic = ['Agent', 'Recall', 'Precision', 'F1']
        column_keys_basic = ['model', 'recall', 'precision', 'f1']
        column_widths_basic = [65, 10, 10, 20]
        print_table('Experiment #1: Basic Metrics', granularity_results_basic, column_display_names_basic,
                    column_keys_basic, column_widths_basic, baselines_results,
                    base_model_results, finetuned_model_results, "f1")

    if args.privacy_utility or args.all:
        # EXPERIMENT #2: Privacy-Utility Tradeoff
        granularity_results_withhold_leak = {granularity: [] for granularity in GRANULARITIES}
        print('Calculating withheld and leaked proportions...')
        for model, model_results_dict in selected_model_results.items():
            granularity, filename = model_results_dict["granularity"], model_results_dict["filename"]
            withheld_proportion, leaked_proportion = compute_withheld_leaked(filename, granularity)
            granularity_results_withhold_leak[granularity].append(
                {"model": model, "withheld_proportion": withheld_proportion,
                 "leaked_proportion": leaked_proportion})
        # print formatted table
        column_display_names_withhold_leak = ['Agent', 'Withheld Proportion', 'Leaked Proportion']
        column_keys_withhold_leak = ['model', 'withheld_proportion', 'leaked_proportion']
        column_widths_withhold_leak = [60, 20, 20]
        print_table('Experiment #2: Privacy-Utility Tradeoff', granularity_results_withhold_leak,
                    column_display_names_withhold_leak, column_keys_withhold_leak, column_widths_withhold_leak,
                    baselines_results, base_model_results, finetuned_model_results)

    if args.geocoding_distance or args.all:
        # EXPERIMENT #3: Geocoding Distance Error (not computed for exact GPS coordinates)
        granularity_results_api_distance = {
            granularity: [] for granularity in GRANULARITIES if granularity != "exact_gps_coordinates"}
        print('Calculating geocoding distance error...')
        for model, model_results_dict in tqdm(selected_model_results.items()):
            granularity, filename = model_results_dict["granularity"], model_results_dict["filename"]
            if granularity == "exact_gps_coordinates":
                continue
            distance_thresholds, _ = compute_api_distance(
                filename, granularity, model_name=model, recompute=args.recompute_geocoding_results)
            granularity_results_api_distance[granularity].append(
                format_distance_row(model, distance_thresholds))

        # Derive the columns from the first row produced at any granularity rather
        # than assuming country-level results exist (they may be filtered out).
        first_row = next(
            (rows[0] for rows in granularity_results_api_distance.values() if rows), None)
        if first_row is None:
            print('No geocoding results to report.')
        else:
            # print a formatted table
            column_keys_api_distance = list(first_row.keys())
            column_display_names_api_distance = [key.title() for key in column_keys_api_distance]
            column_widths_api_distance = [65] + [15] * (len(column_keys_api_distance) - 1)
            print_table('Experiment #3: Geocoding Distance Error', granularity_results_api_distance,
                        column_display_names_api_distance, column_keys_api_distance,
                        column_widths_api_distance,
                        baselines_results, base_model_results, finetuned_model_results)