Both the torch and the jax pipelines apply the following decision rule to set `network_type`.
Torch:

```python
if network_type is not None:
    self.network_type = network_type
else:
    self.network_type = "lan" if self.train_output_type == "logprob" else "cpn"
    print(
        'Setting network type to "lan" or "cpn" based on train_output_type. \n'
        + "Note: This is only a default setting, and can be overwritten by the network_type argument."
    )
```
Jax:

```python
# Identify network type:
if self.model.train_output_type == "logprob":
    network_type = "lan"
elif self.model.train_output_type == "logits":
    network_type = "cpn"
else:
    network_type = "unknown"
    print(
        'Model type identified as "unknown" because '
        "the training_output_type attribute"
        ' of the supplied jax model is neither "logprob", nor "logits"'
    )
```
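Anchoring on `train_output_type` is the basic mistake: both the cpn and the opn networks have `train_output_type="logits"`, so this attribute alone cannot tell them apart, and an opn network ends up labeled "cpn" (in the torch pipeline, unless `network_type` is passed explicitly).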
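To make the collision concrete, here is a minimal standalone reproduction of the jax rule above. The helper name `infer_network_type` is hypothetical and not part of either pipeline:

```python
def infer_network_type(train_output_type: str) -> str:
    """Standalone mirror of the jax pipeline's current rule (hypothetical helper)."""
    if train_output_type == "logprob":
        return "lan"
    if train_output_type == "logits":
        return "cpn"
    return "unknown"


# Both cpn and opn networks are trained with train_output_type="logits",
# so the rule returns the same label for both.
cpn_label = infer_network_type("logits")  # intended "cpn"
opn_label = infer_network_type("logits")  # intended "opn", but gets "cpn"
print(cpn_label, opn_label)
```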
This needs a slight refactor with new logic. One option is to stop identifying/deducing the `network_type` from `train_output_type` altogether and instead have the caller pass it explicitly.
Note:
For backward compatibility, the current inference logic might need to be kept around as a fallback.
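A minimal sketch of the explicit-passing idea, with the current rule kept as a fallback, could look like the following. The function `resolve_network_type`, the `VALID_NETWORK_TYPES` set, and the `"opn"` label string are assumptions for illustration, not existing code. The fallback mirrors the jax rule; the torch rule differs slightly in that it maps every non-`"logprob"` value to `"cpn"`.

```python
import warnings
from typing import Optional

# Hypothetical set of supported labels; "opn" is included so it can be
# selected explicitly instead of being (wrongly) inferred.
VALID_NETWORK_TYPES = {"lan", "cpn", "opn"}


def resolve_network_type(network_type: Optional[str], train_output_type: str) -> str:
    """Prefer an explicitly passed network_type; fall back to the legacy rule."""
    if network_type is not None:
        if network_type not in VALID_NETWORK_TYPES:
            raise ValueError(
                f"network_type must be one of {sorted(VALID_NETWORK_TYPES)}, "
                f"got {network_type!r}"
            )
        return network_type

    # Legacy fallback kept for backward compatibility. It cannot distinguish
    # cpn from opn, because both use train_output_type="logits".
    warnings.warn(
        "network_type not passed; inferring it from train_output_type. "
        "This cannot distinguish 'cpn' from 'opn'; pass network_type explicitly.",
        stacklevel=2,
    )
    if train_output_type == "logprob":
        return "lan"
    if train_output_type == "logits":
        return "cpn"
    return "unknown"
```

With this shape, `resolve_network_type("opn", "logits")` returns `"opn"`, while `resolve_network_type(None, "logits")` still returns `"cpn"` and emits a warning, so existing callers keep working while new code can be explicit.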