-
Notifications
You must be signed in to change notification settings - Fork 1
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
add gnll loss #64
add gnll loss #64
Conversation
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Review mostly about Pythonic stuff.
sbi_lens/normflow/train_model.py
Outdated
|
||
class TrainModel: | ||
|
||
class TrainModelLocal: |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Ok but new name not used in training script
sbi_lens/normflow/train_model.py
Outdated
@@ -35,21 +56,29 @@ def __init__( | |||
loss_name, | |||
nb_pixels, | |||
nb_bins, | |||
dim=None, |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
New argument to the class defined here, but should also be added to a class docstring somehow to understand what it is used for.
sbi_lens/normflow/train_model.py
Outdated
elif loss_name == "loss_for_sbi": | ||
if info_compressor is None: | ||
raise NotImplementedError | ||
raise ValueError("sbi loss needs compressor informations") |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
raise ValueError("sbi loss needs compressor informations") | |
raise ValueError("sbi loss needs compressor information") |
sbi_lens/normflow/train_model.py
Outdated
|
||
if loss_name == "train_compressor_mse": | ||
self.loss = self.loss_mse | ||
elif loss_name == "train_compressor_vmim": | ||
self.loss = self.loss_vmim | ||
elif loss_name == "train_compressor_gnll": |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Could loss_name
be shortened to "mse"
, "vmin"
and "gnll"
?
If so it should be refactored
scripts/train_compressor.py
Outdated
|
||
if args.loss == "train_compressor_gnll": | ||
compressor = hk.transform_with_state( | ||
lambda y: ResNet18(int(dim + ((dim**2) - dim) / 2 + dim))(y, is_training=True) |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
The only difference in your if statement is the input dimension of the ResNet18. Then it should look like
if args.loss == "train_compressor_gnll":
# Comment on how the dim is computed
resnet_dim = int(dim + ((dim**2) - dim) / 2 + dim)
else:
resnet_dim = dim
compressor = hk.transform_with_state(lambda y: ResNet18(resnet_dim)(y, is_training=True))
scripts/train_compressor.py
Outdated
elif args.loss == "train_compressor_gnll": | ||
parameters_compressor = parameters_resnet |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Could be simplified to args.loss in ["train_compressor_mse", "train_compressor_gnll"]
start_lr = 0.0001 | ||
|
||
else: | ||
start_lr = 0.001 |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
You could explicitely define at the top of the training script a
DEFAULT_LEARNING_RATE = 0.001
constant value so this value is not hidden at the bottom of the script.
Same for the BATCH_SIZE for instance, even if you don't modify it.
I add the Gaussian negative log likelihood loss (from https://arxiv.org/pdf/1906.03156.pdf) and changed the train_compressor script accordingly