On instantiation, `feedbax.channel.Channel` is passed an `input_proto`: a PyTree with the same type (structure and array shapes) as the PyTree that the channel will queue. It would be inconvenient to ask the user to provide `input_proto` when constructing their model, since it can be inferred from the rest of the model/state structure. What the user should provide instead is a `where` lambda that picks out the part of the state to be sent through the channel.
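To illustrate why `input_proto` is redundant once a `where` lambda and an example state are available, here is a minimal plain-Python sketch. It assumes states are frozen dataclasses holding tuples of floats, standing in for Equinox modules holding JAX arrays; `MechanicsState`, `FeedbackState`, `zeros_like_tree`, and `infer_input_proto` are hypothetical names, not feedbax API. In JAX, the zeroing step would more naturally be `jax.tree_util.tree_map(jnp.zeros_like, where(state))`.

```python
from dataclasses import dataclass
from typing import Any, Callable


# Hypothetical stand-ins for a state PyTree (feedbax would use JAX arrays).
@dataclass(frozen=True)
class MechanicsState:
    pos: tuple
    vel: tuple


@dataclass(frozen=True)
class FeedbackState:
    mechanics: MechanicsState
    hidden: tuple


def zeros_like_tree(tree: Any) -> Any:
    """Copy a nested tuple/list/dict structure, replacing every leaf with 0.0."""
    if isinstance(tree, (tuple, list)):
        return type(tree)(zeros_like_tree(x) for x in tree)
    if isinstance(tree, dict):
        return {k: zeros_like_tree(v) for k, v in tree.items()}
    return 0.0  # anything else is treated as a leaf


def infer_input_proto(where: Callable[[Any], Any], example_state: Any) -> Any:
    """Hypothetical helper: derive the channel prototype from a `where` selection."""
    return zeros_like_tree(where(example_state))


# Usage: select position and velocity as the feedback to be queued.
state = FeedbackState(
    MechanicsState(pos=(1.0, 2.0), vel=(0.5, -0.5)), hidden=(0.1,) * 3
)
proto = infer_input_proto(lambda s: (s.mechanics.pos, s.mechanics.vel), state)
print(proto)  # ((0.0, 0.0), (0.0, 0.0))
```

The prototype keeps the structure of the selected subtree but none of its values, which is all a channel needs to allocate its queue.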
Currently, `ChannelSpec` is used to specify a `Channel` instance without constructing it. It mostly replicates the field structure of `Channel`, but stores a `where` lambda instead of an `input_proto` (excerpt from `feedbax/feedbax/channel.py`, lines 42 to 53 in 147fb42):

```python
"""Specifies how to build a [`Channel`][feedbax.channel.Channel], with respect to the state PyTree of its owner.

Attributes:
    where: A function that selects the subtree of feedback states.
    delay: The number of previous inputs to store in the queue.
    noise_std: The standard deviation of the noise to add to the output.
"""

where: Callable[[AbstractState], PyTree[Array]]
delay: int = 0
noise_std: Optional[float] = None
```
Currently, only `SimpleFeedback` implements the conversion of `ChannelSpec` to `Channel`, and it does so by assuming that the `where` lambda refers specifically to a part of `SimpleFeedbackState`.
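The conversion that `SimpleFeedback` performs can be sketched roughly as follows. This is a hypothetical plain-Python version (frozen dataclasses, tuples of floats, and a dict as the owner's state), not the feedbax implementation; `build_channel` is an illustrative name. The point it shows is that only the owner has an example state to apply `spec.where` to.

```python
from dataclasses import dataclass
from typing import Any, Callable, Optional


def zeros_like_tree(tree: Any) -> Any:
    """Copy a nested tuple/list/dict structure, replacing every leaf with 0.0."""
    if isinstance(tree, (tuple, list)):
        return type(tree)(zeros_like_tree(x) for x in tree)
    if isinstance(tree, dict):
        return {k: zeros_like_tree(v) for k, v in tree.items()}
    return 0.0


@dataclass(frozen=True)
class ChannelSpec:  # simplified stand-in for feedbax's ChannelSpec
    where: Callable[[Any], Any]
    delay: int = 0
    noise_std: Optional[float] = None


@dataclass(frozen=True)
class Channel:  # simplified stand-in; fields follow the description in this issue
    input_proto: Any
    delay: int = 0
    noise_std: Optional[float] = None


def build_channel(spec: ChannelSpec, example_state: Any) -> Channel:
    """Hypothetical: what an owner such as SimpleFeedback does with a ChannelSpec.

    The owner supplies an example of its own state, so the prototype can be
    taken from whatever `spec.where` selects.
    """
    return Channel(
        input_proto=zeros_like_tree(spec.where(example_state)),
        delay=spec.delay,
        noise_std=spec.noise_std,
    )


# Usage: the owner's state is a plain dict here, so `where` indexes into it.
example_state = {"effector_pos": (0.3, -0.2), "net_hidden": (0.0,) * 4}
spec = ChannelSpec(where=lambda s: s["effector_pos"], delay=5, noise_std=0.1)
channel = build_channel(spec, example_state)
print(channel)  # Channel(input_proto=(0.0, 0.0), delay=5, noise_std=0.1)
```

Because `spec.where` is written against one particular state layout (here, a dict with an `"effector_pos"` key), the conversion only works inside an owner that knows that layout.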
This whole setup seems hacky and I'm trying to find a better way.
A first thought is that `Channel` could have a `where` field, like `ChannelSpec`, and no `input_proto` field. The user could then construct and pass a `Channel` directly, instead of a `ChannelSpec`, when constructing `SimpleFeedback`. During the execution of its stages, `SimpleFeedback` would pass its entire state to the `Channel`, whose `where` would pick out the relevant substate internally.
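A minimal sketch of this alternative, again in plain Python with hypothetical stand-ins (`WhereChannel`, `FeedbackState`, `ChannelState`) and an assumed queue behavior in which each call pushes the newest input and emits the oldest one. Note that `__call__` has to accept and return the owner's whole state type, and has to hard-code where its own `ChannelState` lives (`state.channel`) in order to rebuild that state with `dataclasses.replace`.

```python
from dataclasses import dataclass, replace
from typing import Any, Callable


@dataclass(frozen=True)
class ChannelState:
    output: Any
    queue: tuple  # pending inputs, oldest first


@dataclass(frozen=True)
class FeedbackState:  # hypothetical stand-in for SimpleFeedbackState
    effector_pos: tuple
    channel: ChannelState


@dataclass(frozen=True)
class WhereChannel:  # hypothetical: a Channel with `where` and no `input_proto`
    where: Callable[[Any], Any]

    def __call__(self, state: FeedbackState) -> FeedbackState:
        # Pick out the substate to be sent through the channel.
        x = self.where(state)
        # Push the new input and emit the oldest queued one.
        items = state.channel.queue + (x,)
        new_channel = ChannelState(output=items[0], queue=items[1:])
        # "Model surgery": the channel rebuilds and returns the whole owner state.
        return replace(state, channel=new_channel)


# Usage: the queue length (here 1) is fixed by the initial ChannelState.
ch = WhereChannel(where=lambda s: s.effector_pos)
s0 = FeedbackState(
    effector_pos=(1.0, 2.0),
    # The initial ChannelState still needs zero placeholders shaped like the input.
    channel=ChannelState(output=(0.0, 0.0), queue=((0.0, 0.0),)),
)
s1 = ch(s0)
print(s1.channel.output)  # (0.0, 0.0)
print(s1.channel.queue)  # ((1.0, 2.0),)
```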
However:

- Passing the entire `SimpleFeedbackState` to `Channel` is at odds with the way `where_state` should be specified in `AbstractStagedModel.model_spec`. We would have to do model surgery from within `Channel` to replace the `ChannelState` component of `SimpleFeedbackState`, and return an entire `SimpleFeedbackState`. Then it would not be clear from `SimpleFeedback.model_spec` which substate of `SimpleFeedbackState` is actually relevant to `Channel`.
- Similarly, we would have to set the generic type argument of `Channel` to `SimpleFeedbackState` so that the signature of its `__call__` method agrees with `AbstractStagedModel`.
- `Channel.init` cannot provide a valid initial `ChannelState` without access to an actual example of the subtree selected by `where`. Currently, this example is provided by the `Channel`'s `self.input_proto`. We do not want to pass `input_proto` to `init` as an argument, since all `init` methods are designed to provide a default state.
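To see why an example of the selected subtree is needed at initialization, consider a sketch of a channel queue. It assumes, following the `delay` docstring, that the queue holds the `delay` most recent inputs and that each step emits the oldest one; `ChannelState`, `init_channel_state`, and `step` are hypothetical plain-Python names, not feedbax API. The zero-filled placeholders that occupy the queue before any real input arrives must have the same structure as the input, which only an `input_proto` (or an example state plus `where`) can supply.

```python
from dataclasses import dataclass
from typing import Any


def zeros_like_tree(tree: Any) -> Any:
    """Copy a nested tuple/list structure, replacing every leaf with 0.0."""
    if isinstance(tree, (tuple, list)):
        return type(tree)(zeros_like_tree(x) for x in tree)
    return 0.0


@dataclass(frozen=True)
class ChannelState:
    output: Any
    queue: tuple  # the `delay` most recent inputs, oldest first


def init_channel_state(input_proto: Any, delay: int) -> ChannelState:
    # Without `input_proto` there is no way to know what structure the
    # placeholder entries (and the initial output) should have.
    zeros = zeros_like_tree(input_proto)
    return ChannelState(output=zeros, queue=(zeros,) * delay)


def step(state: ChannelState, x: Any) -> ChannelState:
    # Push the newest input, emit the oldest.
    items = state.queue + (x,)
    return ChannelState(output=items[0], queue=items[1:])


# Usage: with delay=2, each input reaches the output two steps later.
state = init_channel_state(input_proto=(0.0, 0.0), delay=2)
outputs = []
for x in [(1.0, 1.0), (2.0, 2.0), (3.0, 3.0)]:
    state = step(state, x)
    outputs.append(state.output)
print(outputs)  # [(0.0, 0.0), (0.0, 0.0), (1.0, 1.0)]
```

With `delay=0` the queue is empty and each input passes straight through to the output.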