-
Notifications
You must be signed in to change notification settings - Fork 0
Expand file tree
/
Copy pathdistribution_visualization.py
More file actions
228 lines (189 loc) · 9.14 KB
/
Copy pathdistribution_visualization.py
File metadata and controls
228 lines (189 loc) · 9.14 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
import os
import argparse
import json
import typing
from typing import List, Union
import re
from matplotlib import pyplot as plt
import numpy as np
import seaborn as sns
import torch
from src.utils import logging, print_
def main():
    """
    Visualize the distribution metrics of one or more experiment runs.

    Experiment ids have the form ``<exp_name>.<run_name>``. They are taken from
    ``--experiments`` or, when that option is omitted, assembled from the
    ``CURRENT_EXP`` and ``CURRENT_RUN`` environment variables. For every experiment
    the stored metrics are loaded, plotted against the dataset baselines, and the
    resulting figure plus a separate legend-only figure are written as EPS files
    into ``<cwd>/distr_visualizations``.

    Raises:
        ValueError: If more than one checkpoint name is given and their number
            does not match the number of experiments.
    """
    # Parse arguments
    parser = argparse.ArgumentParser()
    parser.add_argument('-e', '--experiments', type=str, nargs='+', default=None,
                        help='Experiment names to resume the visualization from')
    parser.add_argument('-c', '--checkpoints', type=str, nargs='+', default=None,
                        help='Checkpoint names to resume the visualization from')
    parser.add_argument('-d', '--dataset', type=str, default='h36m',
                        help='Dataset to use in visualization')
    args = parser.parse_args()

    # Fall back to the environment when no experiment was passed explicitly
    if args.experiments is None:
        if 'CURRENT_EXP' not in os.environ:
            print_('No experiment name passed. Visualization aborted!')
            return
        print_('Take experiment name from environment.')
        exp_name = os.environ['CURRENT_EXP']
        if 'CURRENT_RUN' not in os.environ:
            print_('No run name passed. Visualization aborted!')
            return
        print_('Take run name from environment.')
        run_name = os.environ['CURRENT_RUN']
        exp_ids = [f'{exp_name}.{run_name}']
    else:
        exp_ids = args.experiments
    dataset = args.dataset

    if args.checkpoints is None:
        checkpoint_names = ['final'] * len(exp_ids)
    else:
        checkpoint_names = args.checkpoints
        # A single checkpoint name is shared by all experiments; otherwise the counts must match
        if len(checkpoint_names) == 1:
            checkpoint_names = checkpoint_names * len(exp_ids)
        elif len(checkpoint_names) != len(exp_ids):
            raise ValueError(f'Number of checkpoint names ({len(checkpoint_names)}) does not match number of experiment names ({len(exp_ids)}).')

    # Load the metrics of every experiment; abort on the first one that cannot be loaded
    data_overall = []
    for exp_id, checkpoint in zip(exp_ids, checkpoint_names):
        try:
            data = load_data(exp_id, checkpoint=checkpoint, dataset=dataset)
        except (IOError, ValueError) as e:
            # Only some exceptions carry a custom ``message`` attribute; Python 3
            # exceptions do not define one, so fall back to the string form.
            print_(getattr(e, 'message', str(e)))
            return
        data_overall.append(data)

    # Set up visualization path
    vis_path = os.path.join(os.getcwd(), 'distr_visualizations')
    os.makedirs(vis_path, exist_ok=True)

    # Load baselines (reported the same way as a failure to load the experiment data)
    try:
        baselines = load_baselines(dataset=dataset)
    except IOError as e:
        print_(getattr(e, 'message', str(e)))
        return

    # Create plot of distribution metrics
    fig, axs = create_plots(data_overall, exp_ids, baselines=baselines)

    # Create list of experiment names concatenated with respective checkpoint names
    exp_runs = [f'{exp_id}.{checkpoint}' for exp_id, checkpoint in zip(exp_ids, checkpoint_names)]

    # Save the plot
    # NOTE(review): ``exp_runs`` is a list, so its repr (brackets and quotes) ends up
    # in the file name. Kept as is so existing output names do not change - confirm
    # whether a joined string was intended.
    fig.savefig(os.path.join(vis_path, f'distribution_metrics_{exp_runs}_{dataset}.eps'), format='eps')

    # Create plot containing only the figures legend
    fig_legend = plt.figure(figsize=(36, 3))
    ax = fig_legend.add_subplot(111)
    ax.axis('off')
    ax.legend(*axs[0].get_legend_handles_labels(), loc='center', fontsize=76, ncol=4, columnspacing=0.8)
    fig_legend.savefig(os.path.join(vis_path, f'distribution_metrics_legend_{exp_runs}_{dataset}.eps'), format='eps')
    print_('Visualization of distribution metrics finished.')
def load_baselines(dataset: str='h36m'):
    """
    Read the stored baseline distribution metrics of a dataset.

    Four tensors are loaded with ``torch.load`` from
    ``configurations/distribution_values`` (relative to the working directory) and
    converted to numpy arrays. The returned list is ordered as: entropy and kld
    with ``norm_False_abs_True``, then entropy and kld with ``norm_True_abs_False``.

    Raises:
        IOError: If a baseline file cannot be read. The exception additionally
            carries a ``message`` attribute naming the offending path.
    """
    file_stems = (
        f'entropy_s16_{dataset}_norm_False_abs_True',
        f'kld_s16_{dataset}_norm_False_abs_True',
        f'entropy_s16_{dataset}_norm_True_abs_False',
        f'kld_s16_{dataset}_norm_True_abs_False',
    )
    loaded = []
    for stem in file_stems:
        path = os.path.join('configurations', 'distribution_values', f'{stem}.pt')
        try:
            tensor = torch.load(path)
        except IOError as e:
            # Attach a readable description for callers that print ``e.message``
            e.message = f'Could not load the baseline distribution metrics from file: {path}.'
            raise e
        loaded.append(tensor.numpy())
    return loaded
def load_data(exp_id: str, action: str='overall', checkpoint: str='final', dataset: str='h36m'):
    """
    Read the evaluated distribution metrics of one experiment run.

    ``exp_id`` is split at the dot into experiment and run name, which are handed to
    the project logger to resolve the run's log directory. The file
    ``eval_results_distribution_<checkpoint>_<dataset>.json`` in that directory is
    parsed and the entry stored under ``action`` is returned.

    Raises:
        IOError: If the results file cannot be opened. The exception additionally
            carries a ``message`` attribute naming the log directory.
        ValueError: If the parsed results contain no entry for ``action``.
    """
    exp_name, run_name = exp_id.split('.')
    # The logger is only used to look up where the run stores its logs
    run_logger = logging.Logger(exp_name=exp_name, run_name=run_name)
    path = os.path.join(os.getcwd(), run_logger.get_path('log'))
    results_file = os.path.join(path, f'eval_results_distribution_{checkpoint}_{dataset}.json')
    try:
        with open(results_file, 'r') as f:
            results = json.load(f)
    except IOError as e:
        # Attach a readable description for callers that print ``e.message``
        e.message = f'Could not load the distribution metrics from file: {path}.'
        raise e
    if action in results.keys():
        return results[action]
    raise ValueError(f'No {action} distribution metrics found in the data. Not yet implemented for this action type.')
def create_plots(data: Union[dict, List[dict]], exp_ids: List[str], baselines:List=None, plot_size: tuple=(35, 10.5)):
    """
    Create one plot of predictions' entropy vs. baseline entropy and one of predictions' kld vs baseline kld.

    Args:
        data: One result dict or a list of them. Each dict is read through its
            ``'entropy'`` and ``'kld'`` entries, which are plotted as curves.
        exp_ids: Experiment ids of the form ``<exp_name>.<run_name>``, one per entry
            of ``data``. They are cleaned up (digits, underscores and the word
            ``model`` removed) and the run name is used as the legend label.
        baselines: Optional sequence holding, in this order, the baseline entropy
            and kld plotted as 'Baseline global' followed by the baseline entropy
            and kld plotted as 'Baseline local'. If ``None``, no baseline curves
            are drawn.
        plot_size: Width and height of the figure.

    Returns:
        The figure and its array of two axes (entropy on the left, kld on the right).
    """
    data = data if isinstance(data, list) else [data]
    plt.rcParams.update({
        "text.usetex": True,
        "font.family": 'Computer Modern Roman',
    })
    fig, axs = plt.subplots(1, 2, figsize=(plot_size[0], plot_size[1]))
    fontsize = 76
    line_width = 8
    frames_per_second = 25

    # The first result determines the length of the baselines and of the x axes
    length_kld = len(data[0]['kld'])
    length_entropy = len(data[0]['entropy'])

    # Nicer less technical legend names
    # (remove all the numbers and underscores and the word model)
    exp_ids = [re.sub(r'[0-9]+', '', exp_id) for exp_id in exp_ids]
    exp_ids = [exp_id.replace('_', ' ').replace('model', '') for exp_id in exp_ids]

    # Curves are drawn in reverse order of the given experiments
    plot_order = list(reversed(list(zip(data, exp_ids))))

    if len(exp_ids) == 2 and any('final' in exp_id for exp_id in exp_ids) and any('global' in exp_id for exp_id in exp_ids):
        # Get two colors for plotting the exact report visualization
        palette = [(0.34509803921568627, 0.4666666666666667, 0.5725490196078431), (0.9254901960784314, 0.3058823529411765, 0.12549019607843137)]
        # Label curves (applied in plotting order)
        labels = ['Global', 'Local']
    else:
        palette = sns.color_palette("colorblind", len(exp_ids) + 2)
        # Build the labels in plotting order so that every curve is labelled with
        # the run name of the experiment it actually shows
        labels = [exp_id.split('.')[1].strip() for _, exp_id in plot_order]

    # Create a plot for each metric pair
    for i, (result, _) in enumerate(plot_order):
        entropy = result['entropy']
        kld = result['kld']
        axs[0].plot(entropy, label=labels[i], color=palette[i], linewidth=line_width)
        axs[1].plot(range(len(kld)), kld, label=labels[i], color=palette[i], linewidth=line_width)

    if baselines is not None:
        # Extract the baseline values
        entropy_norm_false = baselines[0]
        kld_norm_false = baselines[1]
        entropy_norm_true = baselines[2]
        kld_norm_true = baselines[3]
        # Repeat every baseline value along the x axis (linspace with equal endpoints)
        entropy_baseline = np.linspace(entropy_norm_false, entropy_norm_false, length_entropy)
        kld_baseline = np.linspace(kld_norm_false, kld_norm_false, length_kld)
        entropy_norm_true = np.linspace(entropy_norm_true, entropy_norm_true, length_entropy)
        kld_norm_true = np.linspace(kld_norm_true, kld_norm_true, length_kld)
        axs[0].plot(entropy_baseline, '--', label='Baseline global', color=palette[-2], linewidth=line_width)
        axs[1].plot(kld_baseline, '--', label='Baseline global', color=palette[-2], linewidth=line_width)
        axs[0].plot(entropy_norm_true, '--', label='Baseline local', color=palette[-1], linewidth=line_width)
        axs[1].plot(kld_norm_true, '--', label='Baseline local', color=palette[-1], linewidth=line_width)

    # Set xaxis labels
    axs[0].set_xlabel('Time (s)', fontsize=fontsize)
    axs[1].set_xlabel('Time (s)', fontsize=fontsize)
    # Set yaxis labels
    axs[0].set_ylabel('PS Entropy', fontsize=fontsize)
    axs[1].set_ylabel('PS KLD', fontsize=fontsize)

    # Entropy axis: one tick every other second, converted from frame number to
    # seconds (with a framerate of 25fps). The labels are derived from the tick
    # positions so that both always have the same length.
    entropy_ticks = list(range(0, length_entropy, 2 * frames_per_second))
    axs[0].set_xticks(entropy_ticks)
    axs[0].set_xticklabels([tick // frames_per_second for tick in entropy_ticks])
    # KLD axis: every other index is shown and labelled with its own value
    kld_ticks = list(range(0, length_kld, 2))
    axs[1].set_xticks(kld_ticks)
    axs[1].set_xticklabels(kld_ticks)

    # Increase font size of all text; no per-axes legend is drawn here, callers can
    # build one from the axes' legend handles
    for ax in axs:
        ax.tick_params(axis='both', which='major', labelsize=fontsize)
        ax.tick_params(axis='both', which='minor', labelsize=fontsize)
        # Also for titles
        ax.title.set_size(fontsize)
        # Add grid to the plots
        ax.grid(True, linewidth=line_width-5)

    # Increase horizontal space between subplots
    fig.subplots_adjust(wspace=0.2)
    # Add whitespace below the plots
    fig.subplots_adjust(bottom=0.22, left = 0.10, right = 0.99)
    # Return the figure
    return fig, axs
# Run the visualization only when the file is executed as a script, not on import
if __name__ == "__main__":
    main()