-
Notifications
You must be signed in to change notification settings - Fork 0
/
Copy pathutils.py
53 lines (41 loc) · 1.86 KB
/
utils.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
import numpy as np
import matplotlib.pyplot as plt
from matplotlib.colors import ListedColormap
from mpl_toolkits.mplot3d import Axes3D
def draw_line(p1, p2, style="-k", linewidth=1):
plt.plot([p1[0], p2[0]], [p1[1], p2[1]], style, linewidth=linewidth)
def plot_data_points(X, idx):
# Define colormap to match Figure 1 in the notebook
cmap = ListedColormap(["red", "green", "blue"])
c = cmap(idx)
# plots data points in X, coloring them so that those with the same
# index assignments in idx have the same color
plt.scatter(X[:, 0], X[:, 1], facecolors='none', edgecolors=c, linewidth=0.1, alpha=0.7)
def plot_progress_kMeans(X, centroids, previous_centroids, idx, K, i):
# Plot the examples
plot_data_points(X, idx)
# Plot the centroids as black 'x's
plt.scatter(centroids[:, 0], centroids[:, 1], marker='x', c='k', linewidths=3)
# Plot history of the centroids with lines
for j in range(centroids.shape[0]):
draw_line(centroids[j, :], previous_centroids[j, :])
plt.title("Iteration number %d" %i)
def plot_kMeans_RGB(X, centroids, idx, K):
# Plot the colors and centroids in a 3D space
fig = plt.figure(figsize=(16, 16))
ax = fig.add_subplot(221, projection='3d')
ax.scatter(*X.T*255, zdir='z', depthshade=False, s=.3, c=X)
ax.scatter(*centroids.T*255, zdir='z', depthshade=False, s=500, c='red', marker='x', lw=3)
ax.set_xlabel('R value - Redness')
ax.set_ylabel('G value - Greenness')
ax.set_zlabel('B value - Blueness')
ax.w_yaxis.set_pane_color((0., 0., 0., .2))
ax.set_title("Original colors and their color clusters' centroids")
plt.show()
def show_centroid_colors(centroids):
palette = np.expand_dims(centroids, axis=0)
num = np.arange(0,len(centroids))
plt.figure(figsize=(16, 16))
plt.xticks(num)
plt.yticks([])
plt.imshow(palette)