Normally lambdas are used for the where_input and where_state fields of ModelStage objects, which are specified inside of the model_spec of an AbstractStagedModel.
Because lambdas do not have informative string representations, printing an instance's model_spec tells us nothing about its stages beyond their names. For example, here's a printout of model_spec for an instance of SimpleStagedNetwork:
The function feedbax.pprint_model_spec is a little better, since it provides info about the identity of the callable:
hidden: GRUCell
readout: wrapped: Linear
out_nonlinearity: wrapped: identity_func
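To make the contrast concrete, here is a minimal sketch of why the default repr of a lambda is uninformative while a name-based label is more useful. Stage, the placeholder GRUCell class, and callable_label are hypothetical stand-ins invented for this illustration; this is not how feedbax.pprint_model_spec is implemented, and the lambda's two-argument signature is an assumption.

```python
from collections import OrderedDict
from dataclasses import dataclass
from typing import Any, Callable


@dataclass
class Stage:
    """Hypothetical stand-in for a ModelStage; not the feedbax class."""
    callable_: Callable[..., Any]
    where_input: Callable[..., Any]


class GRUCell:
    """Placeholder callable class, here only so there is a name to print."""

    def __call__(self, x: Any) -> Any:
        return x


def callable_label(f: Any) -> str:
    """Use __name__ for plain functions, else the class name of the instance."""
    name = getattr(f, "__name__", None)
    return name if name is not None else type(f).__name__


# Assumed signature for where_input: (input, state) -> part of the input.
spec = OrderedDict(hidden=Stage(GRUCell(), lambda input, state: input))

# The default repr only says "this is a lambda" plus a memory address.
print(repr(spec["hidden"].where_input))

# A name-based label at least identifies the stage's callable.
for stage_name, stage in spec.items():
    print(f"{stage_name}: {callable_label(stage.callable_)}")
```

The label still says nothing about what where_input selects, which is the gap this issue is about.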
Similar to what's done with WhereDict (#14), it would be nice to parse the where_input and where_state lambdas so we can pretty-print the references they contain. In this case, that could look something like:
In this case I've left in <lambda> for the input to GRUCell, which is ravel_pytree(input)[0]. Of course, in a case like this, where the lambdas aren't just references to parts of the input/state, it's less clear how they should be included in the printout.
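One possible way to recover the references, sketched below under stated assumptions, is to call the lambda on proxy objects that record attribute and item accesses, and to fall back to <lambda> whenever the call does anything else. The names _PathRecorder and describe_where are hypothetical, the (input, state) argument list is an assumption, and this is not necessarily the approach taken for WhereDict in #14.

```python
from typing import Any, Callable


class _PathRecorder:
    """Proxy that records the attribute and item accesses made on it."""

    def __init__(self, path: str):
        self._path = path

    def __getattr__(self, name: str) -> "_PathRecorder":
        # Only reached when normal lookup fails, so _path itself is unaffected.
        if name.startswith("__"):
            # Refuse dunder probes so unrelated protocols fail cleanly.
            raise AttributeError(name)
        return _PathRecorder(f"{self._path}.{name}")

    def __getitem__(self, key: Any) -> "_PathRecorder":
        return _PathRecorder(f"{self._path}[{key!r}]")


def describe_where(where: Callable[..., Any], *arg_names: str) -> str:
    """Return the accessed path as a string, or '<lambda>' if not a pure reference."""
    recorders = [_PathRecorder(name) for name in arg_names]
    try:
        result = where(*recorders)
    except Exception:
        # The lambda did something a recorder cannot imitate.
        return "<lambda>"
    if isinstance(result, _PathRecorder):
        return result._path
    return "<lambda>"


# Pure references are recovered as readable paths.
print(describe_where(lambda input, state: state.hidden, "input", "state"))
print(describe_where(lambda input, state: input["task"].target, "input", "state"))

# Anything that computes on its arguments falls back to the placeholder.
print(describe_where(lambda input, state: len(input), "input", "state"))
```

This sketch executes the lambda, so it suits only side-effect-free selectors; inspecting the lambda's source or bytecode would avoid the call but is more fragile.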
It may be unnecessary to provide this information to the user in this form, since they could just refer to the source of model_spec. One downside of expecting the user to read the source is that some model_spec definitions may use arbitrary logic to decide whether to include certain stages or not, which the user would need to parse. For example, in SimpleStagedNetwork:
On the other hand, a function like pprint_model_spec shows exactly the components in a model instance, briefly, in the order they are actually included.
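As a rough illustration of the trade-off, the sketch below builds a spec whose stages depend on constructor-style arguments, then lists what a given instance actually contains. build_spec and its parameters are hypothetical and simplified; this is not the code from feedbax/nn.py.

```python
from collections import OrderedDict
from typing import Any, Callable, Optional


def identity_func(x: Any) -> Any:
    return x


def build_spec(
    use_readout: bool,
    out_nonlinearity: Optional[Callable[[Any], Any]] = None,
) -> "OrderedDict[str, Callable[[Any], Any]]":
    """Hypothetical spec builder: which stages exist depends on the arguments."""
    spec: "OrderedDict[str, Callable[[Any], Any]]" = OrderedDict()
    spec["hidden"] = identity_func  # placeholder for the recurrent cell
    if use_readout:
        spec["readout"] = identity_func  # placeholder for a linear readout
        if out_nonlinearity is not None:
            spec["out_nonlinearity"] = out_nonlinearity
    return spec


# Reading build_spec means following every branch; listing an instance's
# spec shows only the stages that were included, in order.
print(list(build_spec(use_readout=False)))
print(list(build_spec(use_readout=True, out_nonlinearity=identity_func)))
```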
The SimpleStagedNetwork example mentioned above refers to feedbax/feedbax/nn.py, lines 234 to 330 in 2ce8b1c.