Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
6 changes: 6 additions & 0 deletions .gitignore
Original file line number Diff line number Diff line change
@@ -1 +1,7 @@
*.DS_Store

# Byte-compiled / optimized / DLL files
__pycache__/
*.py[cod]
*.so

62 changes: 59 additions & 3 deletions examples/train_sfno.py
Original file line number Diff line number Diff line change
Expand Up @@ -50,6 +50,13 @@
import wandb
wandb.login()

def l1_rel_error(truth, test):
    """
    Computes the per-sample relative L1 error of `test` against `truth`, in percent.

    For every sample along the leading (batch) dimension this evaluates
    100 * mean(|truth - test|) / mean(|truth|), where both means run over all
    remaining dimensions of that sample.

    Parameters:
    truth (torch.Tensor): Reference values; dimension 0 indexes the samples.
    test (torch.Tensor): Values compared against `truth`.

    Returns:
    torch.Tensor: 1D float32 CPU tensor with one entry per sample. An entry is
    inf or nan when the matching `truth` sample is zero everywhere, because
    the denominator is not guarded.
    """
    # The trailing singleton axis lets flatten(start_dim=1) accept 1D input
    # (one scalar per sample) as well as empty batches.
    abs_err = torch.abs(truth - test).unsqueeze(-1).flatten(start_dim=1)
    abs_ref = torch.abs(truth).unsqueeze(-1).flatten(start_dim=1)

    # One reduction for the whole batch instead of a Python loop with a
    # host synchronisation (.item()) per sample.
    percent = abs_err.mean(dim=1) / abs_ref.mean(dim=1) * 100

    # The result has always been a float32 tensor on the CPU; keep that so
    # callers can assign it into CPU buffers.
    return percent.to(device="cpu", dtype=torch.float32)

def l2loss_sphere(solver, prd, tar, relative=False, squared=True):
loss = solver.integrate_grid((prd - tar)**2, dimensionless=True).sum(dim=-1)
if relative:
Expand Down Expand Up @@ -231,16 +238,56 @@ def log_weights_and_grads(model, iters=1):
store_dict = {'iteration': iters, 'grads': grad_dict, 'weights': weights_dict}
torch.save(store_dict, weights_and_grads_fname)


def _as_numpy(field):
    # Converts an array-like (including a torch tensor) to a NumPy array for matplotlib.
    if hasattr(field, "detach"):
        # Tensor-like input: drop autograd tracking and move it to host memory first.
        field = field.detach().cpu()
    return np.asarray(field)


def plot_prediction_vs_target(prd, tar, fname="sfno_prediction.png"):
    """
    Plots targets, predictions, and their absolute difference for one sample.

    The figure has one row per channel and three columns: target, prediction,
    and |prediction - target|. Every panel is a filled contour plot with its
    own colorbar. The figure is saved to `fname` and then closed.

    Parameters:
    prd (array-like): Predicted values of a single sample, indexed as
        [channel, ...]. Each channel slice is handed to `contourf`, so it is
        expected to be a 2D field.
    tar (array-like): Target values with the same shape as `prd`.
    fname (str): Path the figure is written to.
    """
    prd = _as_numpy(prd)
    tar = _as_numpy(tar)
    n_channels = prd.shape[0]

    # One entry per column, in display order: (title, data to draw).
    panels = (
        ("Target", tar),
        ("Prediction", prd),
        ("Absolute Difference", np.abs(prd - tar)),
    )

    # Height scales with the channel count; three channels give the
    # 12 x 10 inch figure. squeeze=False keeps `axes` 2D for one channel.
    fig, axes = plt.subplots(
        n_channels,
        len(panels),
        figsize=(12, 10 * n_channels / 3),
        squeeze=False,
        constrained_layout=True,
    )

    try:
        for col, (title, data) in enumerate(panels):
            for channel in range(n_channels):
                ax = axes[channel, col]
                contour = ax.contourf(data[channel], levels=100, cmap='viridis')

                cbar = fig.colorbar(contour, ax=ax, orientation='vertical')
                cbar.ax.set_ylabel(f'Channel {channel + 1}', rotation=270, labelpad=15)

                ax.set_title(f"{title} - Channel {channel + 1}")
                ax.set_xlabel('Phi (Azimuthal Angle)')
                ax.set_ylabel('Theta (Polar Angle)')

        fig.savefig(fname)
    finally:
        # Close only the figure created here, even if drawing or saving fails,
        # so figures owned by other code are left alone.
        plt.close(fig)

# training function
def train_model(model,
dataloader,
optimizer,
gscaler,
scheduler=None,
nepochs=20,
nepochs=200,
nfuture=0,
num_examples=256,
num_valid=8,
num_valid=64,
loss_fn='l2',
enable_amp=False,
log_grads=0):
Expand Down Expand Up @@ -307,27 +354,36 @@ def train_model(model,
# perform validation
valid_loss = 0
model.eval()
errors = torch.zeros((num_valid))
with torch.no_grad():
for inp, tar in dataloader:
for index, (inp, tar) in enumerate(dataloader):
prd = model(inp)
batch_size = inp.shape[0]
for _ in range(nfuture):
prd = model(prd)
loss = l2loss_sphere(solver, prd, tar, relative=True)

valid_loss += loss.item() * inp.size(0)
errors[batch_size*index:batch_size*(index+1)] = l1_rel_error(tar, prd)

if index == 0:
plot_prediction_vs_target(prd[0].cpu(), tar[0].cpu())

valid_loss = valid_loss / len(dataloader.dataset)

if scheduler is not None:
scheduler.step(valid_loss)



epoch_time = time.time() - epoch_start

print(f'--------------------------------------------------------------------------------')
print(f'Epoch {epoch} summary:')
print(f'time taken: {epoch_time}')
print(f'accumulated training loss: {acc_loss}')
print(f'relative validation loss: {valid_loss}')
print(f'median relative error: {torch.median(errors).item()}')

if wandb.run is not None:
current_lr = optimizer.param_groups[0]['lr']
Expand Down
Loading