-
Notifications
You must be signed in to change notification settings - Fork 0
Expand file tree
/
Copy pathVisualize_training_predictions.py
More file actions
320 lines (260 loc) · 11.9 KB
/
Visualize_training_predictions.py
File metadata and controls
320 lines (260 loc) · 11.9 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
"""
Visualize what the diffusion model predicts during training.
Shows actual noise vs predicted noise/v-param to understand model behavior.
This helps diagnose:
- Whether model predictions are spatially coherent or just averaging
- If the model is learning meaningful denoising patterns
- Differences between eps-param and v-param predictions
"""
import torch
import numpy as np
import matplotlib.pyplot as plt
import yaml
import os
from DIFFUSION import LatentDiffusion
from UNET import UNetModel
from VAE import TemperatureVAE, SimpleConvEncoder, SimpleConvDecoder
from Conditional_dataset import ConditionalTemperatureDataModule
def load_model(checkpoint_path, vae_config_path, parameterization="eps"):
    """Load a trained latent-diffusion model together with its frozen VAE.

    Args:
        checkpoint_path: Path to the diffusion checkpoint (.ckpt). If the file
            does not exist, a warning is printed and the model keeps random
            weights (useful for smoke-testing the visualization pipeline).
        vae_config_path: Path to the VAE YAML config; it supplies both the VAE
            architecture hyperparameters and the VAE checkpoint location.
        parameterization: Diffusion target parameterization, "eps" or "v".

    Returns:
        A LatentDiffusion model set to eval mode.
    """
    # --- Reconstruct the VAE from its config --------------------------------
    with open(vae_config_path, 'r') as f:
        vae_config = yaml.safe_load(f)
    model_config = vae_config['model']
    levels = int(model_config['levels'])
    encoded_channels = int(model_config['encoded_channels'])
    hidden_width = int(model_config['hidden_width'])
    in_dim = int(model_config['in_dim'])
    vae_encoder = SimpleConvEncoder(
        in_dim=in_dim,
        levels=levels,
        channel_list=model_config['channel_list'],
    )
    vae_decoder = SimpleConvDecoder(
        in_dim=in_dim,
        levels=levels,
        channel_list=model_config['channel_list'],
    )
    vae = TemperatureVAE(
        encoder=vae_encoder,
        decoder=vae_decoder,
        kl_weight=float(model_config['kl_weight']),
        encoded_channels=encoded_channels,
        hidden_width=hidden_width,
    )
    checkpoint_path_vae = os.path.join(
        vae_config['paths']['checkpoint_dir'],
        vae_config['paths']['checkpoint_file'],
    )
    state_dict = torch.load(checkpoint_path_vae, map_location="cpu")["state_dict"]
    vae.load_state_dict(state_dict)
    vae.eval()
    # --- Build the UNet denoiser --------------------------------------------
    # Latent channel count must match the VAE bottleneck width; reuse the
    # int-cast value parsed above instead of re-reading the raw config entry.
    unet = UNetModel(
        model_channels=128,
        in_channels=hidden_width,
        out_channels=hidden_width,
        num_res_blocks=2,
        dropout=0.0,
        channel_mult=(1, 2, 2, 4),
        dims=3,
        use_checkpoint=False,
        num_classes=12,  # one class per calendar month
    )
    # --- Wrap VAE + UNet into the latent diffusion model --------------------
    diffusion = LatentDiffusion(
        model=unet,
        autoencoder=vae,
        timesteps=1000,
        beta_schedule="linear",
        parameterization=parameterization,
    )
    # strict=False tolerates checkpoints that carry stale/extra VAE parameters.
    if os.path.exists(checkpoint_path):
        checkpoint = torch.load(checkpoint_path, map_location="cpu")
        diffusion.load_state_dict(checkpoint["state_dict"], strict=False)
        print(f"Loaded checkpoint from {checkpoint_path}")
    else:
        print(f"WARNING: No checkpoint found at {checkpoint_path}, using random weights")
    diffusion.eval()
    return diffusion
def visualize_training_step(diffusion_model, batch, save_dir="training_visualizations"):
    """Visualize model predictions vs. actual noise at several timesteps.

    For each probe timestep: noise the (whitened) latent via the forward
    diffusion process, run the denoiser, and save a figure comparing the
    clean latent, the added noise, the noisy input, the raw model prediction,
    and errors in both prediction space and x0 space.

    Args:
        diffusion_model: LatentDiffusion exposing .autoencoder, .model, the
            noise-schedule buffers (sqrt_alphas_cumprod, ...) and whitening
            stats (z_mean / z_std).
        batch: (x, y, class_labels) tuple as produced by the data module.
        save_dir: Directory where PNGs are written (created if missing).
    """
    os.makedirs(save_dir, exist_ok=True)
    x, y, class_labels = batch
    # Encode the target field into latent space.
    with torch.inference_mode():
        y_encoded = diffusion_model.autoencoder.encoder(y)
        moments = diffusion_model.autoencoder.to_moments(y_encoded)
        if moments.shape[1] == diffusion_model.autoencoder.hidden_width:
            # Deterministic bottleneck: moments ARE the latent.
            z = moments
        else:
            # VAE bottleneck: reparameterization trick z = mu + sigma * eps.
            mean, log_var = torch.chunk(moments, 2, dim=1)
            std = torch.exp(0.5 * log_var)
            eps = torch.randn_like(std)
            z = mean + std * eps
    # Whiten so the latent matches the scale the diffusion model trained on.
    z_whitened = (z - diffusion_model.z_mean) / diffusion_model.z_std
    # Probe timesteps from low noise to near-pure noise.
    timesteps_to_test = [50, 250, 500, 750, 950]
    for t_val in timesteps_to_test:
        t = torch.full((z_whitened.shape[0],), t_val, dtype=torch.long)
        # Fresh noise for the forward diffusion step.
        noise = torch.randn_like(z_whitened)
        # Schedule coefficients broadcast over (B, C, Z, H, W).
        # NOTE: these are sqrt(alpha_bar_t) and sqrt(1 - alpha_bar_t);
        # the old name "a_t_minus_1" was misleading (nothing here is at t-1).
        sqrt_a_t = diffusion_model.sqrt_alphas_cumprod[t].view(-1, 1, 1, 1, 1)
        sqrt_1m_a_t = diffusion_model.sqrt_one_minus_alphas_cumprod[t].view(-1, 1, 1, 1, 1)
        # Forward diffusion: x_t = sqrt(alpha_bar_t) * x_0 + sqrt(1-alpha_bar_t) * eps
        x_noisy = sqrt_a_t * z_whitened + sqrt_1m_a_t * noise
        # Model prediction (eps-hat or v-hat depending on parameterization).
        with torch.no_grad():
            model_output = diffusion_model.model(x_noisy, t, class_labels)
        # Reconstruct x_0 from the prediction, per parameterization.
        if diffusion_model.parameterization == "eps":
            # x0_hat = (x_t - sqrt(1-alpha_bar) * eps_hat) / sqrt(alpha_bar)
            x_0_pred = (x_noisy - sqrt_1m_a_t * model_output) / sqrt_a_t
            x_0_mse = ((x_0_pred - z_whitened) ** 2).mean().item()
        elif diffusion_model.parameterization == "v":
            # x0_hat = sqrt(alpha_bar) * x_t - sqrt(1-alpha_bar) * v_hat
            x_0_pred = sqrt_a_t * x_noisy - sqrt_1m_a_t * model_output
            x_0_mse = ((x_0_pred - z_whitened) ** 2).mean().item()
        else:
            x_0_pred = None
            x_0_mse = None
        # Visualize the first sample in the batch, middle slice of the
        # depth (Z) axis of the 3D latent.
        sample_idx = 0
        z_mid = z_whitened.shape[2] // 2
        original = z_whitened[sample_idx, 0, z_mid].cpu().numpy()
        noise_added = noise[sample_idx, 0, z_mid].cpu().numpy()
        noisy_input = x_noisy[sample_idx, 0, z_mid].cpu().numpy()
        predicted = model_output[sample_idx, 0, z_mid].cpu().numpy()
        x_0_pred_slice = x_0_pred[sample_idx, 0, z_mid].cpu().numpy() if x_0_pred is not None else None
        fig, axes = plt.subplots(2, 4, figsize=(20, 10))
        # Row 1: forward-diffusion process.
        im0 = axes[0, 0].imshow(original, cmap='RdBu_r', vmin=-2, vmax=2)
        axes[0, 0].set_title(f'Original Latent x₀ (whitened)')
        plt.colorbar(im0, ax=axes[0, 0])
        im1 = axes[0, 1].imshow(noise_added, cmap='RdBu_r', vmin=-3, vmax=3)
        axes[0, 1].set_title(f'Actual Noise Added')
        plt.colorbar(im1, ax=axes[0, 1])
        im2 = axes[0, 2].imshow(noisy_input, cmap='RdBu_r', vmin=-3, vmax=3)
        axes[0, 2].set_title(f'Noisy Input (t={t_val})')
        plt.colorbar(im2, ax=axes[0, 2])
        if x_0_pred_slice is not None:
            im_x0 = axes[0, 3].imshow(x_0_pred_slice, cmap='RdBu_r', vmin=-2, vmax=2)
            axes[0, 3].set_title(f'Predicted x̂₀')
            plt.colorbar(im_x0, ax=axes[0, 3])
        else:
            axes[0, 3].axis('off')
        # Row 2: predictions and errors (parameterization-specific).
        if diffusion_model.parameterization == "eps":
            im3 = axes[1, 0].imshow(predicted, cmap='RdBu_r', vmin=-3, vmax=3)
            axes[1, 0].set_title(f'Model Predicted Noise (ε̂)')
            plt.colorbar(im3, ax=axes[1, 0])
            # Error in the space the model is trained in (epsilon space).
            target_slice = noise_added
            error = predicted - target_slice
            im4 = axes[1, 1].imshow(error, cmap='RdBu_r', vmin=-1, vmax=1)
            axes[1, 1].set_title(f'ε-space Error')
            plt.colorbar(im4, ax=axes[1, 1])
            x_0_error = x_0_pred_slice - original
            im5 = axes[1, 2].imshow(x_0_error, cmap='RdBu_r', vmin=-1, vmax=1)
            axes[1, 2].set_title(f'x₀-space Error')
            plt.colorbar(im5, ax=axes[1, 2])
        elif diffusion_model.parameterization == "v":
            # v-target: v = sqrt(alpha_bar_t) * noise - sqrt(1-alpha_bar_t) * x_0
            a_t_scalar = sqrt_a_t[sample_idx, 0, 0, 0, 0].item()
            one_minus_scalar = sqrt_1m_a_t[sample_idx, 0, 0, 0, 0].item()
            actual_v = a_t_scalar * noise_added - one_minus_scalar * original
            im3 = axes[1, 0].imshow(predicted, cmap='RdBu_r', vmin=-3, vmax=3)
            axes[1, 0].set_title(f'Model Predicted v')
            plt.colorbar(im3, ax=axes[1, 0])
            target_slice = actual_v
            error = predicted - target_slice
            im4 = axes[1, 1].imshow(error, cmap='RdBu_r', vmin=-1, vmax=1)
            axes[1, 1].set_title(f'v-space Error')
            plt.colorbar(im4, ax=axes[1, 1])
            x_0_error = x_0_pred_slice - original
            im5 = axes[1, 2].imshow(x_0_error, cmap='RdBu_r', vmin=-1, vmax=1)
            axes[1, 2].set_title(f'x₀-space Error')
            plt.colorbar(im5, ax=axes[1, 2])
        # Statistics panel. MSE is measured against the parameterization's own
        # training target (epsilon or v) over this visualized 2D slice.
        eps_mse = ((predicted - target_slice) ** 2).mean()
        # Sensitivity of x0 to an epsilon-space error: dx0/de = -sqrt(1-a)/sqrt(a)
        alpha_t = sqrt_a_t[sample_idx, 0, 0, 0, 0].item() ** 2  # sqrt_alpha -> alpha
        sensitivity = -(1 - alpha_t) ** 0.5 / alpha_t ** 0.5
        stats_text = f"""
Timestep: {t_val} / 1000
Parameterization: {diffusion_model.parameterization}
Original std: {original.std():.3f}
Noise std: {noise_added.std():.3f}
Noisy std: {noisy_input.std():.3f}
Predicted std: {predicted.std():.3f}
ε-space MSE: {eps_mse:.4f}
x₀-space MSE: {x_0_mse:.4f}
Sensitivity (∂x₀/∂ε): {sensitivity:.3f}
Expected x₀ error: {(sensitivity * eps_mse**0.5):.4f}
"""
        axes[1, 3].text(0.1, 0.5, stats_text, fontsize=10, verticalalignment='center', family='monospace')
        axes[1, 3].axis('off')
        plt.suptitle(f'Training Prediction Visualization - t={t_val}, Month={class_labels[sample_idx].item()}', fontsize=14)
        plt.tight_layout()
        filename = f"{save_dir}/training_t{t_val:04d}_month{class_labels[sample_idx].item()}.png"
        plt.savefig(filename, dpi=150, bbox_inches='tight')
        # BUGFIX: previously the f-string had no placeholder and never
        # reported which file was written.
        print(f"Saved: {filename}")
        plt.close()
    print(f"\n✅ Saved {len(timesteps_to_test)} training visualizations to {save_dir}/")
def main():
    """Entry point: load one trained diffusion model, pull a validation batch,
    and write prediction-visualization figures to disk."""
    banner = "=" * 80
    print(banner)
    print("TRAINING PREDICTION VISUALIZER")
    print(banner)

    # Paths are hardcoded for HPC batch jobs (no interactive input available).
    vae_config_path = "/leonardo_scratch/fast/CNHPC_1990904/fbattini/scripts/VAE_config.yaml"

    # Flip this switch to choose which trained model gets visualized.
    USE_V_PARAM = False
    if USE_V_PARAM:
        checkpoint_path = "checkpoints_diffusion/conditional_monthly_1ch/conditional-monthly-diffusion-epoch=19-val_loss=0.0515.ckpt"
        parameterization = "v"
        save_dir = "visualizations/training_v_param"
        print("\n📊 Visualizing v-parameterization model")
    else:
        checkpoint_path = "checkpoints_diffusion/conditional_monthly_1ch/last.ckpt"
        parameterization = "eps"
        save_dir = "visualizations/training_eps_param"
        print("\n📊 Visualizing eps-parameterization model")

    print(f"\n📦 Loading {parameterization}-parameterization model...")
    diffusion = load_model(checkpoint_path, vae_config_path, parameterization)

    print("📦 Loading data...")
    # The VAE config also carries the dataset location and year splits.
    from VAE_config_utils import load_config
    cfg = load_config(vae_config_path)['data']
    data_module = ConditionalTemperatureDataModule(
        data_path=cfg['data_path'],
        start_year=int(cfg['start_year']),
        end_year=int(cfg['end_year']),
        val_year=int(cfg['val_year']),
        batch_size=4,
        num_workers=2,
    )
    data_module.setup('fit')

    # A single validation batch is enough for the visualizations.
    batch = next(iter(data_module.val_dataloader()))
    print(f"Batch shapes: x={batch[0].shape}, y={batch[1].shape}, labels={batch[2].shape}")

    print(f"\n🎨 Creating visualizations...")
    visualize_training_step(diffusion, batch, save_dir)

    print("\n✅ Done! Check the visualizations to see:")
    print("  - How well the model predicts noise at different timesteps")
    print("  - Whether predictions are spatially coherent or averaging")
    print("  - Differences between high-noise (t=950) and low-noise (t=50) timesteps")
# Run the visualizer only when this file is executed as a script.
if __name__ == "__main__":
    main()