Skip to content

Commit

Permalink
integrate comments
Browse files Browse the repository at this point in the history
  • Loading branch information
Justinezgh committed Nov 21, 2023
1 parent 910a933 commit 8e814e0
Show file tree
Hide file tree
Showing 2 changed files with 51 additions and 48 deletions.
85 changes: 46 additions & 39 deletions sbi_lens/normflow/train_model.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,35 +11,64 @@


class TrainModel:
def __init__(
self,
compressor,
nf,
optimizer,
loss_name,
dim=None,
info_compressor=None,
):
self.compressor = compressor
self.nf = nf
self.optimizer = optimizer
self.dim = dim # summary statistic dimension

if loss_name == "train_compressor_mse":
self.loss = self.loss_mse
elif loss_name == "train_compressor_vmim":
self.loss = self.loss_vmim
elif loss_name == "train_compressor_gnll":
self.loss = self.loss_gnll
if self.dim is None:
raise ValueError("dim should be specified when using gnll compressor")
elif loss_name == "loss_for_sbi":
if info_compressor is None:
raise ValueError("sbi loss needs compressor informations")
else:
self.info_compressor = info_compressor
self.loss = self.loss_nll

def loss_mse(self, params, theta, x, state_resnet):
"""Compute the Mean Squared Error loss
"""
y, opt_state_resnet = self.compressor.apply(params, state_resnet, None, x)

loss = jnp.mean(jnp.sum((y - theta) ** 2, axis=1))

return loss, opt_state_resnet

def loss_mae(self, params, theta, x, state_resnet):
"""Compute the Mean Absolute Error loss
"""
y, opt_state_resnet = self.compressor.apply(params, state_resnet, None, x)

loss = jnp.mean(jnp.sum(jnp.absolute(y - theta), axis=1))

return loss, opt_state_resnet

def loss_vmim(self, params, theta, x, state_resnet):
"""Compute the Variational Mutual Information Maximization loss
"""
y, opt_state_resnet = self.compressor.apply(params, state_resnet, None, x)
log_prob = self.nf.apply(params, theta, y)

return -jnp.mean(log_prob), opt_state_resnet

def loss_nll(self, params, theta, x, _):
y, _ = self.compressor.apply(
self.info_compressor[0], self.info_compressor[1], None, x
)
log_prob = self.nf.apply(params, theta, y)

return -jnp.mean(log_prob), _

def loss_gnll(self, params, theta, x, state_resnet):
"""Compute the Gaussian Negative Log Likelihood loss
"""
y, opt_state_resnet = self.compressor.apply(params, state_resnet, None, x)
y_mean = y[..., : self.dim]
y_var = y[..., self.dim :]
Expand All @@ -55,38 +84,16 @@ def _get_log_prob(y_mean, y_var, theta):

return loss, opt_state_resnet

def __init__(
self,
compressor,
nf,
optimizer,
loss_name,
nb_pixels,
nb_bins,
dim=None,
info_compressor=None,
):
self.compressor = compressor
self.nf = nf
self.optimizer = optimizer
self.nb_pixels = nb_pixels
self.nb_bins = nb_bins
self.dim = dim
def loss_nll(self, params, theta, x, _):
"""Compute the Negative Log Likelihood loss.
This loss is for inference so it requires to have a trained compressor.
"""
y, _ = self.compressor.apply(
self.info_compressor[0], self.info_compressor[1], None, x
)
log_prob = self.nf.apply(params, theta, y)

if loss_name == "train_compressor_mse":
self.loss = self.loss_mse
elif loss_name == "train_compressor_vmim":
self.loss = self.loss_vmim
elif loss_name == "train_compressor_gnll":
self.loss = self.loss_gnll
if self.dim is None:
raise ValueError("dim should be specified when using gnll compressor")
elif loss_name == "loss_for_sbi":
if info_compressor is None:
raise ValueError("sbi loss needs compressor informations")
else:
self.info_compressor = info_compressor
self.loss = self.loss_nll
return -jnp.mean(log_prob), _

@partial(jax.jit, static_argnums=(0,))
def update(self, model_params, opt_state, theta, x, state_resnet=None):
Expand Down
14 changes: 5 additions & 9 deletions scripts/train_compressor.py
Original file line number Diff line number Diff line change
Expand Up @@ -101,11 +101,11 @@ def __call__(self, y):
)

# Width of the ResNet output head. NOTE(review): for the gnll loss the
# head is widened to dim + dim + (dim**2 - dim) / 2 — presumably dim means
# plus the triangular covariance terms of a Gaussian; confirm against
# loss_gnll in train_model.py.
if args.loss == "train_compressor_gnll":
    resnet_dim = int(dim + ((dim**2) - dim) / 2 + dim)
else:
    resnet_dim = dim

# Single construction site for the compressor, parameterized by the head
# width chosen above.
compressor = hk.transform_with_state(
    lambda y: ResNet18(resnet_dim)(y, is_training=True)
)

print("######## TRAIN ########")

Expand All @@ -120,11 +120,7 @@ def __call__(self, y):

# Assemble the trainable parameter set for the chosen loss.
if args.loss == "train_compressor_vmim":
    # vmim trains the flow jointly with the resnet, so merge both sets
    parameters_compressor = hk.data_structures.merge(parameters_resnet, params_nf)
elif args.loss in ["train_compressor_mse", "train_compressor_mae", "train_compressor_gnll"]:
    # these losses train the resnet alone
    parameters_compressor = parameters_resnet


Expand Down

0 comments on commit 8e814e0

Please sign in to comment.