Skip to content

Commit

Permalink
Merge pull request #65 from DifferentiableUniverseInitiative/dlanzier…
Browse files Browse the repository at this point in the history
…i/add_mae

MAE loss function
  • Loading branch information
Justinezgh authored Nov 20, 2023
2 parents 42495c9 + 82b6968 commit 910a933
Show file tree
Hide file tree
Showing 2 changed files with 11 additions and 0 deletions.
7 changes: 7 additions & 0 deletions sbi_lens/normflow/train_model.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,13 @@ def loss_mse(self, params, theta, x, state_resnet):

return loss, opt_state_resnet

def loss_mae(self, params, theta, x, state_resnet):
y, opt_state_resnet = self.compressor.apply(params, state_resnet, None, x)

loss = jnp.mean(jnp.sum(jnp.absolute(y - theta), axis=1))

return loss, opt_state_resnet

def loss_vmim(self, params, theta, x, state_resnet):
y, opt_state_resnet = self.compressor.apply(params, state_resnet, None, x)
log_prob = self.nf.apply(params, theta, y)
Expand Down
4 changes: 4 additions & 0 deletions scripts/train_compressor.py
Original file line number Diff line number Diff line change
Expand Up @@ -122,6 +122,8 @@ def __call__(self, y):
parameters_compressor = hk.data_structures.merge(parameters_resnet, params_nf)
elif args.loss == "train_compressor_mse":
parameters_compressor = parameters_resnet
elif args.loss == "train_compressor_mae":
parameters_compressor = parameters_resnet
elif args.loss == "train_compressor_gnll":
parameters_compressor = parameters_resnet

Expand Down Expand Up @@ -197,6 +199,8 @@ def __call__(self, y):
l_name = "vmim"
elif args.loss == "train_compressor_mse":
l_name = "mse"
elif args.loss == "train_compressor_mae":
l_name = "mae"
elif args.loss == "train_compressor_gnll":
l_name = "gnll"

Expand Down

0 comments on commit 910a933

Please sign in to comment.