add gnll loss #64

Merged
merged 10 commits into from
Mar 5, 2024

Conversation

Justinezgh
Collaborator

I added the Gaussian negative log-likelihood (GNLL) loss (from https://arxiv.org/pdf/1906.03156.pdf) and changed the train_compressor script accordingly.
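
For reference, the textbook negative log-likelihood of a $d$-dimensional Gaussian with predicted mean $\mu(x)$ and covariance $\Sigma(x)$, evaluated at the true parameters $\theta$, is written out below. This is the standard form only; whether the implementation in this PR drops the constant term or parametrizes $\Sigma$ through a triangular factor is not visible on this page, so treat the symbols as assumptions.

```latex
\mathcal{L}_{\mathrm{GNLL}}(\theta, x)
  = \frac{1}{2}\,\bigl(\theta - \mu(x)\bigr)^{\top} \Sigma(x)^{-1} \bigl(\theta - \mu(x)\bigr)
  + \frac{1}{2}\,\log\det \Sigma(x)
  + \frac{d}{2}\,\log(2\pi)
```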

Collaborator

@aboucaud aboucaud left a comment

Review mostly about Pythonic stuff.


-class TrainModel:
+class TrainModelLocal:

OK, but the new name is not used in the training script.

sbi_lens/normflow/train_model.py
@@ -35,21 +56,29 @@ def __init__(
loss_name,
nb_pixels,
nb_bins,
dim=None,

A new argument to the class is defined here, but it should also be documented in the class docstring so that readers understand what it is used for.
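
A minimal sketch of what such a docstring entry could look like, using NumPy-style sections. Only the argument names `loss_name`, `nb_pixels`, `nb_bins` and `dim` come from the hunk above; the real `__init__` takes more arguments than shown, and every description below (including what `dim` means) is an assumption to be replaced by the author's wording.

```python
class TrainModelLocal:
    """Training helper (sketch that only illustrates the docstring layout).

    Parameters
    ----------
    loss_name : str
        Name of the loss to use, e.g. "train_compressor_gnll".
    nb_pixels : int
        Size of the input maps in pixels (assumed meaning).
    nb_bins : int
        Number of bins in the input maps (assumed meaning).
    dim : int, optional
        Dimension of the parameter vector predicted by the compressor
        (assumed meaning). Only needed by losses that split the network
        output into several parts, such as the GNLL loss.
    """

    def __init__(self, loss_name, nb_pixels, nb_bins, dim=None):
        # Store the arguments; the real class does more work here.
        self.loss_name = loss_name
        self.nb_pixels = nb_pixels
        self.nb_bins = nb_bins
        self.dim = dim
```

With the entry in place, `help(TrainModelLocal)` shows what `dim` is for without the reader having to open the training script.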

 elif loss_name == "loss_for_sbi":
     if info_compressor is None:
-        raise NotImplementedError
+        raise ValueError("sbi loss needs compressor informations")

Suggested change
-raise ValueError("sbi loss needs compressor informations")
+raise ValueError("sbi loss needs compressor information")


if loss_name == "train_compressor_mse":
    self.loss = self.loss_mse
elif loss_name == "train_compressor_vmim":
    self.loss = self.loss_vmim
elif loss_name == "train_compressor_gnll":

Could loss_name be shortened to "mse", "vmim" and "gnll"?

If so, it should be refactored.
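
One way such a refactor could look is a dictionary that maps the short names to the loss methods, which also gives a clear error for a misspelled name. This is a sketch only: `LossSelector` is a hypothetical stand-in for the training class, `loss_gnll` is an assumed method name (the hunk shows only `loss_mse` and `loss_vmim`), and the method bodies are placeholders.

```python
class LossSelector:
    """Hypothetical stand-in for the training class, reduced to loss dispatch."""

    def __init__(self, loss_name):
        # Single lookup table: adding a loss means adding one entry here.
        losses = {
            "mse": self.loss_mse,
            "vmim": self.loss_vmim,
            "gnll": self.loss_gnll,  # assumed method name
        }
        if loss_name not in losses:
            raise ValueError(
                f"unknown loss {loss_name!r}; expected one of {sorted(losses)}"
            )
        self.loss = losses[loss_name]

    # Placeholder bodies; the real losses live in train_model.py.
    def loss_mse(self):
        return "mse"

    def loss_vmim(self):
        return "vmim"

    def loss_gnll(self):
        return "gnll"
```

The command-line choices in the training script would need the same renaming, otherwise the two sets of names drift apart.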


if args.loss == "train_compressor_gnll":
    compressor = hk.transform_with_state(
        lambda y: ResNet18(int(dim + ((dim**2) - dim) / 2 + dim))(y, is_training=True)

The only difference between the branches of your if statement is the dimension passed to ResNet18, so it should look like:

if args.loss == "train_compressor_gnll":
    # Comment on how the dim is computed
    resnet_dim = int(dim + ((dim**2) - dim) / 2 + dim)
else:
    resnet_dim = dim
compressor = hk.transform_with_state(lambda y: ResNet18(resnet_dim)(y, is_training=True))
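
As an illustration of the comment asked for in the snippet, the expression `dim + ((dim**2) - dim) / 2 + dim` can be read as a count of three groups of outputs. The interpretation below (a mean vector plus the entries of a lower-triangular factor of the covariance) is an assumption based on how full-covariance Gaussian losses are usually parametrized, not something stated in the PR; `gnll_output_dim` is a hypothetical helper.

```python
def gnll_output_dim(dim: int) -> int:
    """Number of network outputs for a full-covariance Gaussian (assumed layout)."""
    n_mean = dim                       # one predicted mean per parameter
    n_off_diag = dim * (dim - 1) // 2  # strictly lower-triangular entries
    n_diag = dim                       # diagonal entries of the triangular factor
    return n_mean + n_off_diag + n_diag
```

Integer arithmetic avoids the `int(...)` cast around a float division, and naming the three terms documents where the number comes from.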

Comment on lines 125 to 126
elif args.loss == "train_compressor_gnll":
    parameters_compressor = parameters_resnet

Could be simplified to args.loss in ["train_compressor_mse", "train_compressor_gnll"]
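
A sketch of that simplification, assuming the MSE branch has the same body as the GNLL branch quoted above. The helper name, the constant name and the `parameters_other` fallback are hypothetical, since the remaining branches are not shown in the hunk.

```python
# Losses whose compressor parameters are the plain ResNet parameters
# (assumption: the MSE branch does the same thing as the GNLL branch).
RESNET_PARAMETER_LOSSES = ("train_compressor_mse", "train_compressor_gnll")


def select_compressor_parameters(loss_name, parameters_resnet, parameters_other):
    """Pick the parameter set for the compressor with one membership test."""
    if loss_name in RESNET_PARAMETER_LOSSES:
        return parameters_resnet
    # Stand-in for whatever the other branches of the script do.
    return parameters_other
```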

    start_lr = 0.0001

else:
    start_lr = 0.001

You could explicitly define a DEFAULT_LEARNING_RATE = 0.001 constant at the top of the training script, so that this value is not hidden at the bottom of the script.

The same goes for BATCH_SIZE, for instance, even if you don't modify it.
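
A minimal sketch of that layout. The two learning-rate values come from the hunk above; `REDUCED_LEARNING_RATE`, the `use_reduced_rate` flag and the `BATCH_SIZE` value are hypothetical, because the condition that selects the smaller rate and the real batch size are not visible on this page.

```python
# Tunable defaults, declared once at the top of the training script.
DEFAULT_LEARNING_RATE = 0.001
REDUCED_LEARNING_RATE = 0.0001  # value used by the first branch in the hunk
BATCH_SIZE = 128  # placeholder value; the real one is not shown in this PR


def starting_learning_rate(use_reduced_rate: bool) -> float:
    """Return the initial learning rate for the optimizer schedule."""
    # The condition behind `use_reduced_rate` is an assumption of this sketch.
    return REDUCED_LEARNING_RATE if use_reduced_rate else DEFAULT_LEARNING_RATE
```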

@Justinezgh Justinezgh merged commit 668c3c3 into main Mar 5, 2024
4 checks passed
@Justinezgh Justinezgh deleted the u/Justinezgh/add_gnll branch March 5, 2024 12:16
3 participants