Skip to content

Commit

Permalink
added more flow models
Browse files Browse the repository at this point in the history
  • Loading branch information
Dingel321 committed Sep 11, 2024
1 parent 7ffcba3 commit 70e5a01
Showing 1 changed file with 7 additions and 1 deletion.
8 changes: 7 additions & 1 deletion src/cryo_sbi/inference/models/build_models.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,10 +21,16 @@ def build_npe_flow_model(config: dict, **embedding_kwargs) -> nn.Module:

if config["MODEL"] == "MAF":
model = zuko.flows.MAF
elif config["MODEL"] == "GF":
model = zuko.flows.GF
elif config["MODEL"] == "CNF":
model = zuko.flows.CNF
elif config["MODEL"] == "UMNN":
model = zuko.flows. UMNN
elif config["MODEL"] == "NSF":
model = zuko.flows.NSF
elif config["MODEL"] == "SOSPF":
model = partial(zuko.flows.SOSPF, polynomials=16, degree=5)
model = partial(zuko.flows.SOSPF, polynomials=8, degree=5)
else:
raise NotImplementedError(
f"Model : {config['MODEL']} has not been implemented yet!"
Expand Down

0 comments on commit 70e5a01

Please sign in to comment.