`SimpleFeedback` allows for multiple channels of feedback to the neural network, each with its own delay and noise. For example, one typical configuration is to feed back "proprioceptive" variables (e.g. arm joint configuration, or muscle states) at a short delay, and "visual" variables (e.g. position of the end of the arm) at a longer delay.
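As a rough illustration of what a per-channel delay and noise level means, here is a minimal sketch of a single delayed, noisy channel. `DelayedNoisyChannel` and this `ChannelState` are hypothetical stand-ins written for this example, not feedbax's `Channel` or `ChannelState`; implementing the delay as a fixed-length FIFO buffer is an assumption.

```python
from dataclasses import dataclass
from typing import NamedTuple

import jax
import jax.numpy as jnp


class ChannelState(NamedTuple):
    # Hypothetical: FIFO buffer of the last `delay` observations (oldest first).
    queue: jax.Array


@dataclass(frozen=True)
class DelayedNoisyChannel:
    # Hypothetical stand-in for a feedback channel (not feedbax's `Channel`).
    delay: int  # steps between an observation and its output; must be >= 1
    noise_std: float  # standard deviation of additive Gaussian noise

    def init(self, example: jax.Array) -> ChannelState:
        # Zero-filled buffer with one slot per step of delay.
        return ChannelState(jnp.zeros((self.delay,) + example.shape, example.dtype))

    def __call__(self, obs: jax.Array, state: ChannelState, key: jax.Array):
        # Emit the oldest buffered observation, plus noise...
        delayed = state.queue[0]
        noise = self.noise_std * jax.random.normal(key, delayed.shape, delayed.dtype)
        # ...then drop it from the buffer and append the newest observation.
        queue = jnp.concatenate([state.queue[1:], obs[None]], axis=0)
        return delayed + noise, ChannelState(queue)
```

A short-delay proprioceptive channel and a longer-delay visual channel would then be two instances with different parameters, e.g. `DelayedNoisyChannel(delay=2, noise_std=0.01)` and `DelayedNoisyChannel(delay=5, noise_std=0.1)` (illustrative values only).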
In particular, `SimpleFeedback` has a field `channels: PyTree[Channel]`. At construction time, the user supplies a `PyTree[ChannelSpec]` or a container of mappings; this is used to construct a `PyTree[Channel]`, which in turn is used to construct a `MultiModel` (see the `"update_feedback"` stage in `feedbax/feedbax/bodies.py`, lines 175 to 183 in b73dfb8). However, see #3 (Remove `feedbax.channel.ChannelSpec`?).
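The spec-to-channel conversion can be pictured as a single `tree_map` whose leaves are the specs. In this sketch, `ChannelSpec`, `Channel`, and `build_channels` are hypothetical simplifications (the real classes carry more fields). Treating a plain dict as a spec when it has a `"where"` key is an assumption made here so that mappings and `ChannelSpec` instances can be mixed in one container.

```python
from dataclasses import dataclass
from typing import Any, Callable

import jax


@dataclass(frozen=True)
class ChannelSpec:
    # Hypothetical: which part of the state to feed back, and how.
    where: Callable[[Any], Any]
    delay: int
    noise_std: float = 0.0


@dataclass(frozen=True)
class Channel:
    # Hypothetical stand-in for a constructed channel.
    delay: int
    noise_std: float


def build_channels(specs):
    # Normalize either form (ChannelSpec or mapping) to a ChannelSpec.
    def to_spec(x):
        return x if isinstance(x, ChannelSpec) else ChannelSpec(**x)

    # Stop the traversal at anything that describes one channel.
    def is_spec_leaf(x):
        return isinstance(x, ChannelSpec) or (isinstance(x, dict) and "where" in x)

    def to_channel(x):
        spec = to_spec(x)
        return Channel(spec.delay, spec.noise_std)

    # The output PyTree has the same container structure as `specs`.
    return jax.tree_util.tree_map(to_channel, specs, is_leaf=is_spec_leaf)
```

For example, `build_channels({"proprio": {"where": ..., "delay": 2}, "vision": ChannelSpec(...)})` returns a dict with the same two keys, each holding a `Channel`.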
`MultiModel` is a subclass of `AbstractModel` that has a field `models: PyTree[AbstractModel]`, and expects to be passed `state` and `input` whose tree structures match that of `models`. When called, it maps over the models, inputs, states, and per-model random keys (`feedbax/feedbax/_model.py`, lines 129 to 146 in b73dfb8):
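```python
# TODO: This is hacky, because I want to pass intervenor stuff through entirely. See `staged`
# Call each submodel on its own input, state, and key; the intervention
# parameters (`input.intervene`) are passed unchanged to every submodel.
return jax.tree_map(
    lambda model, input_, state, key: model(
        ModelInput(input_, input.intervene), state, key
    ),
    self.models,
    input.input,
    state,
    self._get_keys(key),
    is_leaf=lambda x: isinstance(x, AbstractModel),
)
```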
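To isolate the mapping pattern from the rest of feedbax, here is a self-contained toy version. `ModelInput`, `AbstractModel`, `Scale`, and this `MultiModel` are simplified hypothetical stand-ins, not the real classes. It uses `jax.tree_util.tree_map`, and the `_get_keys` helper, whose real implementation isn't shown above, is assumed to split one key per submodel into the same tree structure as `models`.

```python
from typing import Any, NamedTuple

import jax
import jax.numpy as jnp


class ModelInput(NamedTuple):
    # Hypothetical: per-model inputs plus intervenor params passed through as-is.
    input: Any
    intervene: Any


class AbstractModel:
    # Marker base class so `is_leaf` can stop the traversal at model objects.
    pass


class Scale(AbstractModel):
    # Toy submodel: scales its input and counts how often it was called.
    def __init__(self, factor):
        self.factor = factor

    def __call__(self, model_input, state, key):
        return model_input.input * self.factor, state + 1


class MultiModel(AbstractModel):
    def __init__(self, models):
        self.models = models

    def _get_keys(self, key):
        # One PRNG key per submodel, arranged in the same structure as `models`.
        is_model = lambda x: isinstance(x, AbstractModel)
        leaves, treedef = jax.tree_util.tree_flatten(self.models, is_leaf=is_model)
        subkeys = list(jax.random.split(key, len(leaves)))
        return jax.tree_util.tree_unflatten(treedef, subkeys)

    def __call__(self, input, state, key):
        # Returns a tree of (output, new_state) pairs, one per submodel.
        return jax.tree_util.tree_map(
            lambda model, input_, state_, key_: model(
                ModelInput(input_, input.intervene), state_, key_
            ),
            self.models,
            input.input,
            state,
            self._get_keys(key),
            is_leaf=lambda x: isinstance(x, AbstractModel),
        )
```

Because `is_leaf` stops the traversal at each submodel, the other trees (`input.input`, `state`, and the keys) only need `models` as a prefix of their structure: a channel's state can itself be an arbitrary PyTree, and it is passed whole to that channel.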
The PyTree structure of `input.input` matches `models` because of the `tree_map` performed in the definition of `where_input` for the `"update_feedback"` stage.
The structure of `state` matches too, because `MultiModel` (like all `AbstractModel` subclasses) provides an `init` method, which here returns a `PyTree[ChannelState]`; this is used to generate any initial state that is passed to the model.
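Both structure-matching steps can be sketched with plain dicts: one `tree_map` over per-channel selectors produces the inputs, and another over those inputs produces the initial states, so both end up with the channels' structure. The selectors, `where_input`, and `init_states` below are hypothetical simplifications, not feedbax's actual code.

```python
import jax
import jax.numpy as jnp

# Hypothetical per-channel selectors, analogous to each channel spec's `where`.
wheres = {
    "proprio": lambda s: s["joints"],
    "vision": lambda s: s["effector"],
}


def where_input(system_state):
    # Apply every selector to the same system state. Functions are PyTree
    # leaves, so the result has the same dict structure as `wheres`.
    return jax.tree_util.tree_map(lambda where: where(system_state), wheres)


def init_states(example_inputs, delays):
    # One zero-filled delay buffer per channel, again with the same structure.
    return jax.tree_util.tree_map(
        lambda x, d: jnp.zeros((d,) + x.shape, x.dtype), example_inputs, delays
    )
```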
Is there a better way to include, in a model, a PyTree of similar components that are all executed as part of a single model stage? With the current approach, intervenors can be added to individual `Channel` objects, but referring to those objects can be inconvenient (e.g. `my_simple_feedback.channels.models['vision']`).
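One possible way to shorten such references, sketched here with hypothetical frozen dataclasses rather than feedbax's Equinox modules, is to give the container a `__getitem__` plus an out-of-place update helper. `Channel`, `MultiModel`, `with_model`, and the string `"noise"` standing in for an intervenor are all illustrative assumptions.

```python
from dataclasses import dataclass, replace
from typing import Any, Mapping


@dataclass(frozen=True)
class Channel:
    # Hypothetical stand-in for a feedback channel.
    delay: int
    intervenors: tuple = ()


@dataclass(frozen=True)
class MultiModel:
    # Hypothetical stand-in: a named collection of submodels.
    models: Mapping[str, Any]

    def __getitem__(self, name):
        # Shortcut so callers can write `channels["vision"]`
        # instead of `channels.models["vision"]`.
        return self.models[name]

    def with_model(self, name, model):
        # Return a copy with one submodel replaced; the original is unchanged.
        return replace(self, models={**self.models, name: model})
```

With Equinox modules, the same out-of-place update is usually written with `equinox.tree_at`; an accessor like this only changes how the target is named, not how the update is applied.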
I'm not sure how `vmap` could be used here: vectorizing over channels would require stacking them along a shared batch axis, but different channels can carry data of different shapes and dtypes.
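To make the limitation concrete: `jax.vmap` maps over a shared leading axis, so it only applies when the channels' data can be stacked into one array. This sketch, with a hypothetical `add_noise` operation standing in for a channel, shows the stacked case next to the `tree_map` fallback that also works for heterogeneous channels.

```python
import jax
import jax.numpy as jnp


def add_noise(obs, noise_std, key):
    # Hypothetical per-channel operation: additive Gaussian noise.
    return obs + noise_std * jax.random.normal(key, obs.shape, obs.dtype)


# Case 1: channels with identical shape and dtype can be stacked and vmapped.
stacked_obs = jnp.stack([jnp.ones(3), 2.0 * jnp.ones(3)])  # (n_channels, 3)
stacked_std = jnp.array([0.0, 0.0])                        # one std per channel
stacked_keys = jax.random.split(jax.random.PRNGKey(0), 2)  # one key per channel
stacked_out = jax.vmap(add_noise)(stacked_obs, stacked_std, stacked_keys)


# Case 2: channels with different shapes cannot share a batch axis,
# so map over a PyTree of channels instead.
def noise_tree(obs_tree, std_tree, key):
    leaves, treedef = jax.tree_util.tree_flatten(obs_tree)
    keys = jax.tree_util.tree_unflatten(treedef, list(jax.random.split(key, len(leaves))))
    return jax.tree_util.tree_map(add_noise, obs_tree, std_tree, keys)


hetero_out = noise_tree(
    {"proprio": jnp.ones(4), "vision": jnp.zeros(2)},
    {"proprio": 0.0, "vision": 0.0},
    jax.random.PRNGKey(1),
)
```

A possible middle ground would be to group channels that happen to share shape and dtype and `vmap` within each group, but the `tree_map` form is the only one that covers the general case.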
The use of `ModelInput` in `MultiModel` is also not ideal. I think it makes sense for `MultiModel` to be a subclass of `AbstractModel` rather than `AbstractStagedModel`; however, `ModelInput` is specifically used for carrying intervenor parameters along with other model inputs, and intervenors are associated with instances of `AbstractStagedModel`, not with `AbstractModel` in general. See #12 for a more general discussion of `ModelInput`.