
Dissociate CPN / OPN cleanly #25

@AlexanderFengler

Both the torch and the jax pipelines apply the following decision rule to set network_type:

Torch:

        if network_type is not None:
            self.network_type = network_type
        else:
            self.network_type = "lan" if self.train_output_type == "logprob" else "cpn"
            print(
                'Setting network type to "lan" or "cpn" based on train_output_type. \n'
                + "Note: This is only a default setting, and can be overwritten by the network_type argument."
            )

Jax:


        # Identify network type:
        if self.model.train_output_type == "logprob":
            network_type = "lan"
        elif self.model.train_output_type == "logits":
            network_type = "cpn"
        else:
            network_type = "unknown"
            print(
                'Model type identified as "unknown" because '
                "the training_output_type attribute"
                ' of the supplied jax model is neither "logprob", nor "logits"'
            )

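Anchoring on train_output_type is the basic mistake: both the cpn and the opn networks have train_output_type="logits", so the rule cannot tell them apart. In jax an opn is always labeled "cpn", and in torch it is labeled "cpn" whenever network_type is not passed explicitly.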
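To make the collision concrete, here is a minimal sketch that lifts both rules quoted above into standalone functions. The names `torch_network_type` and `jax_network_type` are hypothetical; in the actual code the logic sits inside the respective constructors.

```python
from typing import Optional


def torch_network_type(train_output_type: str, network_type: Optional[str] = None) -> str:
    """Hypothetical extraction of the torch rule: explicit override, else deduce."""
    if network_type is not None:
        return network_type
    return "lan" if train_output_type == "logprob" else "cpn"


def jax_network_type(train_output_type: str) -> str:
    """Hypothetical extraction of the jax rule: deduce only, no override."""
    if train_output_type == "logprob":
        return "lan"
    elif train_output_type == "logits":
        return "cpn"
    return "unknown"


# Both cpn and opn networks are trained on logits, so both get the same label.
for actual in ("cpn", "opn"):
    print(actual, "->", torch_network_type("logits"), jax_network_type("logits"))
```

Both iterations print `cpn` for the deduced type. Only the torch path offers an escape hatch (the explicit `network_type` argument), and the two rules also disagree on unrecognized values: torch falls through to `"cpn"`, jax returns `"unknown"`.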

This needs a small refactor with new logic. One option: stop identifying/deducing network_type from train_output_type, and have the caller pass network_type explicitly instead.
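One way to pass the type actively is to reverse the direction of inference, so that network_type is required and train_output_type is derived from it. The sketch below assumes a hypothetical `NetworkSpec` container and a hypothetical `TRAIN_OUTPUT_TYPE` mapping built only from the facts stated in this issue (lan uses "logprob"; cpn and opn use "logits"); it is not the existing API of either pipeline.

```python
from dataclasses import dataclass

# Hypothetical mapping: the network type determines the training output type,
# never the other way round (several network types can share "logits").
TRAIN_OUTPUT_TYPE = {
    "lan": "logprob",
    "cpn": "logits",
    "opn": "logits",
}


@dataclass(frozen=True)
class NetworkSpec:
    """Hypothetical container in which network_type must be given explicitly."""

    network_type: str

    def __post_init__(self) -> None:
        # Reject anything we cannot map, instead of guessing a label.
        if self.network_type not in TRAIN_OUTPUT_TYPE:
            raise ValueError(
                f"network_type must be one of {sorted(TRAIN_OUTPUT_TYPE)}, "
                f"got {self.network_type!r}"
            )

    @property
    def train_output_type(self) -> str:
        # Derived, so it can never contradict network_type.
        return TRAIN_OUTPUT_TYPE[self.network_type]


spec = NetworkSpec("opn")
print(spec.network_type, spec.train_output_type)
```

Because the lookup goes from many-to-one keys to values, an opn keeps its own identity while still reporting "logits" to whatever training code needs the loss type.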

Note:

For backward compatibility, the current deduction logic may need to be kept as a fallback for callers that do not pass network_type.
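A backward-compatible resolver could let an explicit network_type win and fall back to the old deduction otherwise, warning when the "logits" case is ambiguous. `resolve_network_type` and `VALID_NETWORK_TYPES` are hypothetical names. Raising on an unrecognized train_output_type is a design choice of this sketch; today torch silently picks "cpn" and jax returns "unknown".

```python
import warnings
from typing import Optional

VALID_NETWORK_TYPES = ("lan", "cpn", "opn")


def resolve_network_type(network_type: Optional[str], train_output_type: str) -> str:
    """Hypothetical helper: explicit network_type wins; otherwise use the legacy rule."""
    if network_type is not None:
        if network_type not in VALID_NETWORK_TYPES:
            raise ValueError(f"unknown network_type {network_type!r}")
        return network_type

    # Legacy path: deduce from train_output_type, as both pipelines do now.
    if train_output_type == "logprob":
        return "lan"
    if train_output_type == "logits":
        # FutureWarning is shown to end users by default (DeprecationWarning is not).
        warnings.warn(
            'train_output_type="logits" is ambiguous between "cpn" and "opn"; '
            'defaulting to "cpn". Pass network_type explicitly.',
            FutureWarning,
            stacklevel=2,
        )
        return "cpn"
    raise ValueError(f"cannot deduce network_type from {train_output_type!r}")


print(resolve_network_type("opn", "logits"))
print(resolve_network_type(None, "logprob"))
```

Existing callers that omit network_type keep getting "cpn" for logits models, but now see a warning nudging them to pass the type explicitly.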
