All Feedbax models (subclasses of `AbstractModel`) have the signature `(input: ModelInput, state: StateT, *, key: PRNGKeyArray) -> StateT`. `StateT` was originally bound to `AbstractState` in `AbstractModel`; it is now bound to `PyTree[Array]` -- see #24.
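For reference, a minimal sketch of that interface, assuming an Equinox module with a generic state type; everything beyond the stated signature is an assumption:

```python
from abc import abstractmethod
from typing import Generic, TypeVar

import equinox as eqx
from jaxtyping import PRNGKeyArray

StateT = TypeVar("StateT")  # now bound to PyTree[Array], per #24


class AbstractModel(eqx.Module, Generic[StateT]):
    """Sketch of the shared model interface; only the signature is from the issue."""

    @abstractmethod
    def __call__(
        self, input: "ModelInput", state: StateT, *, key: PRNGKeyArray
    ) -> StateT:
        ...
```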
In `AbstractStagedModel` subclasses (#19), we perform a sequence of state operations by passing subsets of `input` and `state` to model components, and using the return values to make out-of-place updates to `state`.
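For instance, a single stage might look roughly like this, with the component and the state field names purely hypothetical:

```python
import equinox as eqx


def run_stage(component, stage_input, state, *, key):
    # Call the component on a subset of the input and state...
    new_substate = component(stage_input, state.mechanics, key=key)  # 'mechanics' is hypothetical
    # ...and write its return value back into the state out-of-place.
    return eqx.tree_at(lambda s: s.mechanics, state, new_substate)
```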
An `AbstractModel` is generally a PyTree containing other `AbstractModel` nodes; i.e. Feedbax models are hierarchical. Typically, the outermost node in the model PyTree is an instance of `Iterator`, which is essentially a loop over a single step of the model (e.g. a `SimpleFeedback` instance) where all of the actual state operations happen.
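The "loop over a single step" idea can be sketched with `jax.lax.scan`; this is only an illustration of the concept, not `Iterator`'s actual implementation:

```python
import jax


def iterate(step, input, init_state, *, key, n_steps):
    """Apply a single model step n_steps times, carrying the state."""
    keys = jax.random.split(key, n_steps)

    def body(state, key):
        state = step(input, state, key=key)
        return state, state  # carry the new state, and also record it

    _, states = jax.lax.scan(body, init_state, keys)
    return states  # time-stacked history of states
```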
The `input` to the outermost model node is not selected from the `input` to another `AbstractModel` that contains it, because there is none. Instead, its input is provided by an instance of `AbstractTask`. This task information is any trial-by-trial data that is unconditional on the internal operations of the model. For example, a reaching task like `SimpleReaches` will provide the model with the goal position it is expected to reach, and the model will ultimately forward this to the controller (neural network) component.
An issue arises when we need to schedule interventions on a task/model pair that already exists. Interventions may change on a trial-by-trial basis, and any systematic trial-by-trial variations are specified by an `AbstractTask`. In particular, if the parameters of an `AbstractIntervenor` are expected to change across trials, then an `AbstractTask` should provide those changing parameters as part of the `input` to the model. The model then needs to make sure that these parameters are matched up with the right instance of `AbstractIntervenor`.
Perhaps there is a way for `schedule_intervenor` to work with `AbstractTask` to structure the intervention parameters so that, at each level of the model, `AbstractStagedModel.__call__` can send them on to the right component, until they reach the component that contains the instance of `AbstractIntervenor` they pertain to. I have not figured out how to do this.
My current solution: when an intervenor is scheduled with `schedule_intervenor`, assign it a string label that is unique among all the intervenors aggregated over all levels of the model PyTree. Intervention parameters are then included in `input` as a flat mapping from these unique labels to parameters. This flat mapping is passed as-is down through the hierarchy of model components; every `AbstractStagedModel` sees the same mapping, and simply tries to match the unique labels of its own intervenors against those in the mapping.
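For example, the flat mapping might look something like this, with hypothetical labels and parameter fields:

```python
# Hypothetical flat mapping, aggregated over all levels of the model PyTree.
intervene = {
    "CurlField_0": {"amplitude": 2.0},   # e.g. an intervenor on the mechanics
    "NetworkClamp_1": {"active": True},  # e.g. an intervenor on the network
}
# Every AbstractStagedModel receives this entire mapping, and matches only
# the labels belonging to its own intervenors.
```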
This is what `ModelInput` is for: it's an `eqx.Module` with two fields, `input` and `intervene`. The `input` field contains the usual task information, which, once it reaches the outermost `AbstractStagedModel` in the model, is selectively passed on to certain component(s) depending on the definition of `model_spec` (again, typically it's all sent to the neural network). The `intervene` field, on the other hand, contains the flat mapping of intervention parameters, and is passed on as-is.
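In other words, something like the following, where the concrete field annotations are my assumptions rather than the real ones:

```python
from typing import Mapping

import equinox as eqx
from jaxtyping import Array, PyTree


class ModelInput(eqx.Module):
    input: PyTree[Array]                    # task info; routed per model_spec
    intervene: Mapping[str, PyTree[Array]]  # flat label-to-parameters mapping; passed as-is
```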
So, in `AbstractStagedModel.__call__` (see `feedbax/_staged.py`, lines 152 to 160 at 8f080c6) we see something like the following. Here, we:
1. pass a subset of the model inputs to the current stage, selecting `subinput` out of `input.input` (I haven't thought of a better name; maybe `input.task_input`);
2. if the component to be called is an `AbstractModel`, it accepts `ModelInput` and might contain interventions, so we pass a reconstructed `ModelInput` with the same `intervene` value (i.e. the flat mapping) but with only `subinput` as `input`. A sketch of this re-wrapping follows below.
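Concretely, the re-wrapping logic might look something like this; all names other than `ModelInput` and `AbstractModel` are hypothetical, not feedbax's actual internals:

```python
def call_component(component, subinput, input, state, *, key):
    # Nested AbstractModels accept ModelInput and may contain intervenors,
    # so re-wrap: pass only the selected subinput, but forward the full
    # flat intervene mapping unchanged.
    if isinstance(component, AbstractModel):
        return component(
            ModelInput(input=subinput, intervene=input.intervene),
            state,
            key=key,
        )
    # Other callables just receive the selected subset of the task input.
    return component(subinput, state, key=key)
```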
This seems pretty hacky to me and I'm not sure how it should be done better. I've considered adding another argument to the signature of `AbstractModel`, but that doesn't seem better. Also, I suppose I don't have to use `ModelInput` at all, and could just type `input` as a tuple.
This should be simplified/obviated if we replace the staged model approach with a DAG approach (#28). In particular, input dependencies will be handled lazily. So I am tagging this issue as "won't fix".