-
Notifications
You must be signed in to change notification settings - Fork 2
/
Copy pathvisualize.py
56 lines (39 loc) · 1.76 KB
/
visualize.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
from __future__ import print_function
import os
import yaml
import argparse
import numpy as np
import seaborn as sns
import matplotlib.pyplot as plt
from sklearn.manifold import TSNE
from data import DataLoader
from model import get_model
def scatter(x, labels, config):
    """Plot 2-D points coloured by class and save the figure as ``tsne.png``.

    Each class index is also written at the median position of the points
    that carry that label.

    Args:
        x: Array whose first two columns (``x[:, 0]``, ``x[:, 1]``) are the
            plot coordinates.
        labels: NumPy array of class indices, one per row of ``x``. It is
            cast to ``int`` so it can index the colour palette.
        config: Mapping that provides ``config["data"]["num_classes"]``,
            ``config["run-title"]`` and ``config["paths"]["save"]``.
    """
    num_classes = config["data"]["num_classes"]
    palette = np.array(sns.color_palette("hls", num_classes))

    # Select the non-interactive backend before any figure is created.
    plt.switch_backend('agg')
    fig, ax = plt.subplots()

    # The builtin ``int`` replaces the ``np.int`` alias, which newer NumPy
    # releases no longer provide.
    ax.scatter(x[:, 0], x[:, 1], lw=0, s=40, alpha=0.2,
               c=palette[labels.astype(int)])

    for idx in range(num_classes):
        members = x[labels == idx, :]
        # A class with no samples has no median; skip it rather than
        # computing NaN coordinates for the text label.
        if members.shape[0] == 0:
            continue
        xtext, ytext = np.median(members, axis=0)
        ax.text(xtext, ytext, str(idx), fontsize=20)

    ax.set_title("{} T-SNE".format(config["run-title"]))
    fig.savefig(os.path.join(config["paths"]["save"], "tsne.png"))
    # Release the figure so repeated calls do not accumulate open figures.
    plt.close(fig)
if __name__ == "__main__":
    # Script entry point: load the config, embed a batch of data with the
    # saved model, project the embeddings to 2-D with t-SNE and plot them.
    parser = argparse.ArgumentParser(description='Model Paramaters')
    parser.add_argument('-c', '--config', type=str, default="config.yaml", help='path of config file')
    args = parser.parse_args()

    # ``yaml.safe_load`` replaces ``yaml.load(file)``: newer PyYAML releases
    # require an explicit Loader, and the safe loader does not construct
    # arbitrary Python objects from the config file.
    with open(args.config, 'r') as config_file:
        config = yaml.safe_load(config_file)

    paths = config["paths"]
    data = config["data"]

    dataloader = DataLoader(config)
    dataloader.load()

    input_shape = (data["imsize"], data["imsize"], data["imchannel"])
    # ``top=False`` is passed through to the project's ``get_model``;
    # presumably it omits the classification head so ``predict`` returns
    # embeddings — confirm against model.py.
    model = get_model(input_shape, config, top=False)
    model.load_weights(paths["load"], by_name=True)

    # NOTE(review): ``k = -1`` presumably requests the whole dataset —
    # confirm against DataLoader.get_random_batch.
    X_batch, y_batch = dataloader.get_random_batch(k = -1)

    # An earlier variant fed the flattened inputs to t-SNE directly
    # (``X_batch.reshape(-1, 784)``) instead of the model's output.
    embeddings = model.predict(X_batch, batch_size=config["train"]["batch-size"], verbose=1)

    # NOTE(review): recent scikit-learn releases rename TSNE's ``n_iter``
    # to ``max_iter``; verify against the installed version.
    tsne = TSNE(n_components=2, perplexity=config["tsne"]["perplexity"], verbose=1, n_iter=config["tsne"]["n_iter"])
    tsne_embeds = tsne.fit_transform(embeddings)

    scatter(tsne_embeds, y_batch, config)