Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
816 changes: 816 additions & 0 deletions predictions_data_vis.ipynb

Large diffs are not rendered by default.

96 changes: 96 additions & 0 deletions predictions_data_vis.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,96 @@
import pandas as pd
import numpy as np
import matplotlib.pyplot as plt
from matplotlib import cm
from matplotlib.colors import Normalize

def get_color_for_value(value, cmap, min_value, max_value):
norm = Normalize(vmin=min_value, vmax=max_value)
return cmap(norm(value))

def create_color_map(color_matrix, cmap=cm.coolwarm):
norm = Normalize(vmin=np.min(color_matrix), vmax=np.max(color_matrix))
cmap = cm.coolwarm # You can choose a different colormap if desired
norm = Normalize(vmin=min(map(min, color_matrix)), vmax=max(map(max, color_matrix)))
sm = plt.cm.ScalarMappable(cmap=cmap, norm=norm)
sm.set_array([])
return cmap, sm


def draw_grid(matrix, error_matrix, color_matrix, filename='output_plot.pdf'): # Set the default filename
# # Create a color map
# norm = Normalize(vmin=np.min(color_matrix), vmax=np.max(color_matrix))
# cmap = cm.coolwarm # You can choose a different colormap if desired
# norm = Normalize(vmin=min(map(min, color_matrix)), vmax=max(map(max, color_matrix)))
# sm = plt.cm.ScalarMappable(cmap=cmap, norm=norm)
# sm.set_array([])
cmap, sm = create_color_map(color_matrix)

rows, cols = len(matrix), len(matrix[0])

fig, ax = plt.subplots()

# Draw horizontal lines
for i in range(rows + 1):
ax.axhline(i, color='black')

# Draw vertical lines
for i in range(cols + 1):
ax.axvline(i, color='black')

# Display numbers in the grid with at most 3 digits and fill cells based on color_matrix
for i in reversed(range(rows)):
for j in reversed(range(cols)):
number = matrix[i][j]
formatted_number = '{:.3f}'.format(number)[:6] # Ensure at most 3 digits before the decimal point

error_value = error_matrix[i][j]
color = get_color_for_value(error_value, cmap, np.min(color_matrix), np.max(color_matrix))

ax.add_patch(plt.Rectangle((j, i), 1, 1, fill=True, color=color))
ax.text(j + 0.5, i + 0.5, formatted_number, ha='center', va='center', fontsize=12, color='white')

ax.set_xlim(0, cols)
ax.set_ylim(0, rows)
ax.set_xticks([]) # Remove x-axis ticks
ax.set_yticks([]) # Remove y-axis ticks
ax.invert_yaxis()

# Add color bar
cbar = plt.colorbar(sm, ax=ax, orientation='vertical', pad=0.02)
cbar.set_label('Absolute Error', rotation=270, labelpad=15)
cbar.set_ticks([np.min(color_matrix), np.max(color_matrix)])
cbar.set_ticklabels(['{:.3f}'.format(np.min(color_matrix)), '{:.3f}'.format(np.max(color_matrix))])

# Add label above the matrix (centered with the matrix)
plt.text(cols / 2, rows + 0.5, 'Predicted Values', ha='center', va='center', fontsize=14)

plt.grid(False)

# Save the plot as a PDF file
plt.savefig(filename, format='pdf')

# Show the plot (if needed)
# plt.show()

def main():
color_matrix = None
base_path = "./train-results/predicted-data-vis/"

data_files = ["burn_6.csv", "burn_11.csv", "burn_16.csv", "burn_21.csv"]
for file in data_files:
df = pd.read_csv(base_path + file)
df['absolute_error'] = abs(df['Actual'] - df['Predicted'])

actual = np.array(df['Actual']).reshape(6, 6)
predicted = np.array(df['Predicted']).reshape(6, 6)
abs_error = np.array(df['absolute_error']).reshape(6, 6)

# set the color scale to have the same range for all plots
if file == "burn_6.csv":
color_matrix = abs_error

draw_grid(predicted, abs_error, color_matrix, filename=f'{base_path}{file.replace(".csv", "")}.pdf')

if __name__ == '__main__':
main()
37 changes: 37 additions & 0 deletions train-results/predicted-data-vis/burn_11.csv
Original file line number Diff line number Diff line change
@@ -0,0 +1,37 @@
Actual,Predicted
33.00999999999999,2.441566
0.13,1.287153
30.04,9.662184
0.04,1.4504021
0.0,0.57912457
0.0,1.5314125
0.02,0.33348063
4.25,4.697654
0.18,1.2343149
0.04,1.4911749
0.0,0.50423425
0.01,0.71346486
0.0,2.164581
0.0,1.8728352
1.0,0.56869173
0.19,1.6881206
0.0,1.2043515
0.0,0.57622427
0.46,0.62170124
0.08,0.12295965
0.67,0.73713386
0.81,1.0150232
0.28,0.786072
0.03,3.3052664
0.9599999999999999,1.2515504
0.1,0.37647727
0.12,0.77600485
0.04,0.7376241
0.39,1.1295924
0.0,1.2982188
0.04,1.1077411
0.67,2.7119973
0.03,0.8327027
0.09,1.1176233
12.039999999999997,7.7226386
0.02,1.3766172
Binary file added train-results/predicted-data-vis/burn_11.pdf
Binary file not shown.
37 changes: 37 additions & 0 deletions train-results/predicted-data-vis/burn_16.csv
Original file line number Diff line number Diff line change
@@ -0,0 +1,37 @@
Actual,Predicted
0.03,0.09476566
0.0,0.14770453
0.0,0.028910527
0.0,0.16033459
0.0,0.3812819
0.0,0.029445633
0.0,0.12857425
0.0,0.34443325
0.0,0.5202863
0.0,0.30236974
0.0,0.15392835
0.0,0.011927954
0.0,0.21076563
0.0,0.12555586
0.0,0.22733401
0.0,0.31207517
0.0,0.086362086
0.0,0.04986767
0.0,0.39596963
1.2,0.2759399
0.0,0.09705978
0.03,0.21685342
0.0,0.09518449
0.0,0.08952272
0.0,0.079145834
0.0,0.041401234
0.0,0.02609378
0.0,0.33272585
0.0,0.17639832
0.0,0.09980161
0.0,0.2981202
0.0,0.035559975
0.0,0.01701772
0.0,0.21755119
0.0,0.07301146
0.0,0.08434749
Binary file added train-results/predicted-data-vis/burn_16.pdf
Binary file not shown.
37 changes: 37 additions & 0 deletions train-results/predicted-data-vis/burn_21.csv
Original file line number Diff line number Diff line change
@@ -0,0 +1,37 @@
Actual,Predicted
0.0,0.25413516
0.0,0.089969926
0.0,0.10162594
0.0,0.16334783
0.0,0.100038044
0.0,0.058781154
0.0,0.15500419
0.0,0.21634062
0.0,0.40450987
0.0,0.074236505
0.0,0.06423112
0.0,0.009109741
0.0,0.11614813
0.0,0.16305806
0.0,0.32475373
0.0,0.119550474
0.0,0.048834827
0.0,0.019283405
0.0,0.4140485
0.0,0.21384089
0.0,0.13307574
0.0,0.34860623
0.0,0.15205687
0.0,0.029667513
0.0,0.1426917
0.0,0.12611139
0.0,0.2047455
0.0,0.23884013
0.0,0.06072749
0.0,0.0063803126
0.02,0.2802802
0.0,0.6642211
0.0,0.07595109
0.0,0.20884931
0.0,0.08108409
0.0,0.021940004
Binary file added train-results/predicted-data-vis/burn_21.pdf
Binary file not shown.
37 changes: 37 additions & 0 deletions train-results/predicted-data-vis/burn_6.csv
Original file line number Diff line number Diff line change
@@ -0,0 +1,37 @@
Actual,Predicted
0.20000000000000004,0.12333784
6.01,1.7319317
0.22,0.69740134
0.01,0.14848521
0.0,0.036329575
0.0,0.07671892
0.0,0.02537092
0.8400000000000001,0.20883751
7.279999999999999,0.46068564
0.0,0.02859186
0.0,0.057903297
0.0,0.042104594
0.02,0.06016002
0.0,0.47121012
0.02,0.14980581
0.03,0.08425264
0.0,0.052084606
0.0,0.040004183
0.22,0.17706552
0.0,0.13211879
0.0,0.16416736
0.63,0.2605115
0.02,0.013563856
0.01,0.038166553
3.12,0.8644007
0.02,0.3797151
0.04,0.15504032
4.4,1.4288893
0.01,0.0021220404
0.0,0.034538448
0.0,0.32518405
298.0400000000001,27.212484
0.02,0.5952123
0.01,0.44371238
0.0,0.22072323
0.0,0.05316649
Binary file added train-results/predicted-data-vis/burn_6.pdf
Binary file not shown.