Pretty printing of inputs and outputs of model stages #18

Open
mlprt opened this issue Feb 25, 2024 · 0 comments
Labels: enhancement (New feature or request)

Comments

@mlprt (Owner) commented Feb 25, 2024

Normally, lambdas are used for the where_input and where_state fields of the ModelStage objects specified inside the model_spec of an AbstractStagedModel.

Because lambdas do not have informative string representations, printing an instance's model_spec tells us nothing about its stages except their names. For example, here's a printout of model_spec for an instance of SimpleStagedNetwork:

OrderedDict([('hidden',
              ModelStage(
                callable=<function <lambda>>,
                where_input=<function <lambda>>,
                where_state=<function <lambda>>,
                intervenors=None
              )),
             ('readout',
              ModelStage(
                callable=<function <lambda>>,
                where_input=<function <lambda>>,
                where_state=<function <lambda>>,
                intervenors=None
              )),
             ('out_nonlinearity',
              ModelStage(
                callable=<function <lambda>>,
                where_input=<function <lambda>>,
                where_state=<function <lambda>>,
                intervenors=None
              ))])

The function feedbax.pprint_model_spec is a little better, since it at least identifies each stage's callable:

hidden: GRUCell
readout: wrapped: Linear
out_nonlinearity: wrapped: identity_func

Similarly to what's done with WhereDict (#14), it would be nice to parse the where_input and where_state lambdas so that we can pretty-print the references they contain. In this case, that could look something like

hidden: GRUCell(<lambda>, state.hidden) -> state.hidden
readout: wrapped: Linear(state.hidden) -> state.output
out_nonlinearity: wrapped: identity_func(state.output) -> state.output

Here I've left <lambda> in place for the input to GRUCell, which is ravel_pytree(input)[0]. Of course, in cases like this, where the lambdas aren't just references to parts of the input/state, it's less clear how they should appear in the printout.
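
As a rough illustration of the parsing idea, here is a minimal sketch using only the standard library. It is not existing feedbax code, and lambda_to_reference is a hypothetical name; it recovers a dotted reference string from simple accessor lambdas and falls back to <lambda> for anything more complicated:

# A minimal sketch, assuming nothing about feedbax internals: recover a dotted
# reference like "state.hidden" from a simple accessor lambda, using only ast
# and inspect. `lambda_to_reference` is a hypothetical helper name.
import ast
import inspect
import textwrap

def lambda_to_reference(fn) -> str:
    try:
        src = textwrap.dedent(inspect.getsource(fn)).strip()
        tree = ast.parse(src)
    except (OSError, SyntaxError, TypeError):
        # No source available, or the lambda is embedded in an expression that
        # does not parse in isolation (e.g. split across several lines).
        return "<lambda>"
    # Take the first lambda in the parsed statement; if several lambdas share a
    # line, this may pick the wrong one, which a real implementation could
    # disambiguate via fn.__code__.co_firstlineno and column offsets.
    node = next((n for n in ast.walk(tree) if isinstance(n, ast.Lambda)), None)
    if node is None:
        return "<lambda>"
    # Only unwrap plain attribute chains rooted in an argument name, e.g.
    # `lambda state: state.hidden` -> "state.hidden"; anything else (such as
    # `lambda input, _: ravel_pytree(input)[0]`) stays opaque.
    parts = []
    body = node.body
    while isinstance(body, ast.Attribute):
        parts.append(body.attr)
        body = body.value
    if isinstance(body, ast.Name):
        parts.append(body.id)
        return ".".join(reversed(parts))
    return "<lambda>"

# e.g. lambda_to_reference(lambda state: state.hidden) == "state.hidden"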

It may be unnecessary to provide this information to the user in this form, since they could simply refer to the source of model_spec. One downside of expecting the user to read the source is that some model_spec definitions use arbitrary logic to decide whether or not to include certain stages, which the user would then have to work through. For example, in SimpleStagedNetwork:

feedbax/feedbax/nn.py, lines 234 to 330 at 2ce8b1c:

@property
def model_spec(self) -> OrderedDict[str, ModelStage]:
    """Specifies the network model stages: layers, nonlinearities, and noise.

    Only includes stages for the encoding layer, readout layer, hidden noise, and
    hidden nonlinearity, if the user respectively requests them at the time of
    construction.

    !!! NOTE
        Inspects the instantiated hidden layer to determine if it is a stateful
        network (e.g. an RNN). If not (e.g. Linear), it wraps the layer so that
        it plays well with the state-passing of `AbstractStagedModel`. This assumes
        that stateful layers will take 2 positional arguments, and stateless layers
        only 1.
    """
    if n_positional_args(self.hidden) == 1:
        hidden_module = lambda self: wrap_stateless_callable(self.hidden)
        if isinstance(self.hidden, eqx.nn.Linear):
            logger.warning(
                "Network hidden layer is linear but no hidden "
                "nonlinearity is defined"
            )
    else:
        # #TODO: revert this!
        # def tmp(self):
        #     def wrapper(input, state, *, key):
        #         return self.hidden(input, jnp.zeros_like(state))
        #     return wrapper
        # hidden_module = lambda self: tmp(self)
        hidden_module = lambda self: self.hidden

    if self.encoder is None:
        spec = OrderedDict(
            {
                "hidden": ModelStage(
                    callable=hidden_module,
                    where_input=lambda input, _: ravel_pytree(input)[0],
                    where_state=lambda state: state.hidden,
                ),
            }
        )
    else:
        spec = OrderedDict(
            {
                "encoder": ModelStage(
                    callable=lambda self: lambda input, state, *, key: self.encoder(
                        input
                    ),
                    where_input=lambda input, _: ravel_pytree(input)[0],
                    where_state=lambda state: state.encoding,
                ),
                "hidden": ModelStage(
                    callable=hidden_module,
                    where_input=lambda input, state: state.encoding,
                    where_state=lambda state: state.hidden,
                ),
            }
        )

    if self.hidden_nonlinearity is not None:
        spec |= {
            "hidden_nonlinearity": ModelStage(
                callable=lambda self: wrap_stateless_callable(
                    self.hidden_nonlinearity, pass_key=False
                ),
                where_input=lambda input, state: state.hidden,
                where_state=lambda state: state.hidden,
            ),
        }

    if self.hidden_noise_std is not None:
        spec |= {
            "hidden_noise": ModelStage(
                callable=lambda self: self._add_hidden_noise,
                where_input=lambda input, state: state.hidden,
                where_state=lambda state: state.hidden,
            ),
        }

    if self.readout is not None:
        spec |= {
            "readout": ModelStage(
                callable=lambda self: wrap_stateless_callable(self.readout),
                where_input=lambda input, state: state.hidden,
                where_state=lambda state: state.output,
            ),
            "out_nonlinearity": ModelStage(
                callable=lambda self: wrap_stateless_callable(
                    self.out_nonlinearity, pass_key=False
                ),
                where_input=lambda input, state: state.output,
                where_state=lambda state: state.output,
            ),
        }

    return spec

On the other hand, a function like pprint_model_spec shows exactly which components a model instance contains, briefly, and in the order they are actually included.
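
If the parsed references were available, a pprint_model_spec-style printout could be extended roughly as follows. This is a sketch only, not the actual feedbax implementation: it assumes the hypothetical lambda_to_reference helper above and the ModelStage fields shown earlier, and it glosses over the "wrapped:" prefixes and the extra state argument that a stateful callable like GRUCell would also display:

# Hypothetical extension of a pprint_model_spec-style printout; not actual
# feedbax API. Assumes `lambda_to_reference` from the sketch above.
def pprint_stage_io(model) -> None:
    for name, stage in model.model_spec.items():
        # Stage callables are specified as `lambda self: <module>`, so call
        # them with the model to recover the underlying module/function.
        module = stage.callable(model)
        inputs = lambda_to_reference(stage.where_input)
        output = lambda_to_reference(stage.where_state)
        print(f"{name}: {type(module).__name__}({inputs}) -> {output}")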

mlprt added the enhancement label on Feb 25, 2024