Skip to content

Commit

Permalink
inject pose during training
Browse files Browse the repository at this point in the history
  • Loading branch information
Dingel321 committed Jun 10, 2024
1 parent e4e4d51 commit 6405bca
Show file tree
Hide file tree
Showing 2 changed files with 36 additions and 10 deletions.
17 changes: 9 additions & 8 deletions src/cryo_sbi/inference/models/estimator_models.py
Original file line number Diff line number Diff line change
Expand Up @@ -95,7 +95,7 @@ def __init__(

self.npe = NPE(
1,
output_embedding_dim,
output_embedding_dim+4, # 4 for the pose in quaternion
transforms=num_transforms,
build=flow,
hidden_features=[*[hidden_flow_dim] * num_hidden_flow, 128, 64],
Expand All @@ -105,7 +105,7 @@ def __init__(
self.embedding = embedding_net()
self.standardize = Standardize(theta_shift, theta_scale)

def forward(self, theta: torch.Tensor, x: torch.Tensor) -> torch.Tensor:
def forward(self, theta: torch.Tensor, x: torch.Tensor, pose: torch.Tensor) -> torch.Tensor:
"""
Forward pass of the NPE model
Expand All @@ -116,10 +116,10 @@ def forward(self, theta: torch.Tensor, x: torch.Tensor) -> torch.Tensor:
Returns:
torch.Tensor: Log probability of the posterior.
"""

return self.npe(self.standardize(theta), torch.cat([self.embedding(x), pose], dim=1))

return self.npe(self.standardize(theta), self.embedding(x))

def flow(self, x: torch.Tensor):
def flow(self, x: torch.Tensor, pose: torch.Tensor):
"""
Conditions the posterior on an image.
Expand All @@ -129,9 +129,10 @@ def flow(self, x: torch.Tensor):
Returns:
zuko.flows.Flow: The posterior distribution.
"""
return self.npe.flow(self.embedding(x))

def sample(self, x: torch.Tensor, shape=(1,)) -> torch.Tensor:
return self.npe.flow(torch.cat([self.embedding(x), pose], dim=1))

def sample(self, x: torch.Tensor, pose: torch.Tensor, shape=(1,)) -> torch.Tensor:
"""
Generate samples from the posterior distribution.
Expand All @@ -143,5 +144,5 @@ def sample(self, x: torch.Tensor, shape=(1,)) -> torch.Tensor:
torch.Tensor: Samples from the posterior distribution.
"""

samples_standardized = self.flow(x).sample(shape)
samples_standardized = self.flow(x, pose).sample(shape)
return self.standardize.transform(samples_standardized)
29 changes: 27 additions & 2 deletions src/cryo_sbi/inference/train_npe_model.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
from typing import Union
import json
import torch
import torch.nn as nn
import numpy as np
import torch.optim as optim
from torch.utils.data import TensorDataset
Expand Down Expand Up @@ -44,6 +45,28 @@ def load_model(
return estimator


class NPELoss_pose(nn.Module):

def __init__(self, estimator: nn.Module):
super().__init__()

self.estimator = estimator

def forward(self, theta: torch.Tensor, x: torch.Tensor, pose: torch.Tensor) -> torch.Tensor:
r"""
Arguments:
theta: The parameters :math:`\theta`, with shape :math:`(N, D)`.
x: The observation :math:`x`, with shape :math:`(N, L)`.
Returns:
The scalar loss :math:`l`.
"""

log_p = self.estimator(theta, x, pose)

return -log_p.mean()


def npe_train_no_saving(
image_config: str,
train_config: str,
Expand Down Expand Up @@ -115,7 +138,7 @@ def npe_train_no_saving(
train_config, model_state_dict, device, train_from_checkpoint
)

loss = NPELoss(estimator)
loss = NPELoss_pose(estimator)
optimizer = optim.AdamW(
estimator.parameters(), lr=train_config["LEARNING_RATE"], weight_decay=0.001
)
Expand Down Expand Up @@ -151,15 +174,17 @@ def npe_train_no_saving(
num_pixels,
pixel_size,
)
for _indices, _images in zip(
for _indices, _images, _quaternions in zip(
indices.split(train_config["BATCH_SIZE"]),
images.split(train_config["BATCH_SIZE"]),
quaternions.split(train_config["BATCH_SIZE"])
):
losses.append(
step(
loss(
_indices.to(device, non_blocking=True),
_images.to(device, non_blocking=True),
_quaternions.to(device, non_blocking=True)
)
)
)
Expand Down

0 comments on commit 6405bca

Please sign in to comment.