From bd582872647b9a2201d20c2f5989e801763c79af Mon Sep 17 00:00:00 2001 From: Denise Date: Tue, 24 Oct 2023 16:55:29 +0200 Subject: [PATCH 1/2] add mae loss function --- sbi_lens/normflow/train_model.py | 7 +++++++ 1 file changed, 7 insertions(+) diff --git a/sbi_lens/normflow/train_model.py b/sbi_lens/normflow/train_model.py index 8ca4cd1..0a42cee 100644 --- a/sbi_lens/normflow/train_model.py +++ b/sbi_lens/normflow/train_model.py @@ -18,6 +18,13 @@ def loss_mse(self, params, theta, x, state_resnet): return loss, opt_state_resnet + def loss_mae(self, params, theta, x, state_resnet): + y, opt_state_resnet = self.compressor.apply(params, state_resnet, None, x) + + loss = jnp.mean(jnp.sum(jnp.absolute(y - theta), axis=1)) + + return loss, opt_state_resnet + def loss_vmim(self, params, theta, x, state_resnet): y, opt_state_resnet = self.compressor.apply(params, state_resnet, None, x) log_prob = self.nf.apply(params, theta, y) From 82b696853431b270cdba89958911bb78e3aeedee Mon Sep 17 00:00:00 2001 From: Denise Date: Tue, 24 Oct 2023 16:59:42 +0200 Subject: [PATCH 2/2] script updated --- scripts/train_compressor.py | 4 ++++ 1 file changed, 4 insertions(+) diff --git a/scripts/train_compressor.py b/scripts/train_compressor.py index f95c051..83baef5 100644 --- a/scripts/train_compressor.py +++ b/scripts/train_compressor.py @@ -122,6 +122,8 @@ def __call__(self, y): parameters_compressor = hk.data_structures.merge(parameters_resnet, params_nf) elif args.loss == "train_compressor_mse": parameters_compressor = parameters_resnet +elif args.loss == "train_compressor_mae": + parameters_compressor = parameters_resnet elif args.loss == "train_compressor_gnll": parameters_compressor = parameters_resnet @@ -197,6 +199,8 @@ def __call__(self, y): l_name = "vmim" elif args.loss == "train_compressor_mse": l_name = "mse" +elif args.loss == "train_compressor_mae": + l_name = "mae" elif args.loss == "train_compressor_gnll": l_name = "gnll"