Remove feedbax.channel.ChannelSpec? #3

Open
mlprt opened this issue Feb 16, 2024 · 0 comments
mlprt commented Feb 16, 2024

On its instantiation, feedbax.channel.Channel is passed an input_proto -- a PyTree with the same type (structure & array shapes) as the PyTree the channel will queue. It would be inconvenient to ask the user to provide input_proto to construct their model, since it can be inferred from the rest of the model/state structure. What they should provide is a where lambda that picks out the part of the state that will be sent through the channel.
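To make the inference concrete, here is a minimal sketch in plain Python. Nested dataclasses and lists stand in for the state PyTree and its arrays, and `MechanicsState`, `ModelState`, and `infer_input_proto` are hypothetical names used only for illustration, not part of feedbax.

```python
from dataclasses import dataclass, field
from typing import Any, Callable, List


@dataclass
class MechanicsState:
    # Lists of floats stand in for arrays in this sketch.
    pos: List[float] = field(default_factory=lambda: [0.0, 0.0])
    vel: List[float] = field(default_factory=lambda: [0.0, 0.0])


@dataclass
class ModelState:
    mechanics: MechanicsState = field(default_factory=MechanicsState)
    hidden: List[float] = field(default_factory=lambda: [0.0] * 4)


def infer_input_proto(where: Callable[[ModelState], Any], example_state: ModelState) -> Any:
    """Apply `where` to an example state to obtain a prototype of the channel input."""
    return where(example_state)


# The user supplies only the selector; the prototype follows from the state.
where = lambda state: state.mechanics.pos
input_proto = infer_input_proto(where, ModelState())
```

The point is only that, given a default state for the owning model, the selector is enough to recover an example of what the channel will queue.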

Currently, ChannelSpec is used for specifying a Channel instance, without constructing it. It mostly replicates the field structure of Channel, but instead of storing an input_proto it stores a where lambda.

class ChannelSpec(eqx.Module):
    """Specifies how to build a [`Channel`][feedbax.channel.Channel], with respect to the state PyTree of its owner.

    Attributes:
        where: A function that selects the subtree of feedback states.
        delay: The number of previous inputs to store in the queue.
        noise_std: The standard deviation of the noise to add to the output.
    """
    where: Callable[[AbstractState], PyTree[Array]]
    delay: int = 0
    noise_std: Optional[float] = None

Currently, only SimpleFeedback implements the conversion of ChannelSpec to Channel, by assuming that the where lambda refers specifically to a part of SimpleFeedbackState.
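Roughly, the conversion amounts to evaluating the spec's `where` on an example of the owner's state and copying over the remaining fields. The following is a sketch under that assumption, not the actual SimpleFeedback code: `SpecSketch`, `ChannelSketch`, and `channel_from_spec` are hypothetical stand-ins, and a nested dict plays the role of SimpleFeedbackState.

```python
from dataclasses import dataclass
from typing import Any, Callable, Optional


@dataclass
class SpecSketch:
    # Mirrors the fields of ChannelSpec.
    where: Callable[[Any], Any]
    delay: int = 0
    noise_std: Optional[float] = None


@dataclass
class ChannelSketch:
    # Same fields, but with a concrete prototype instead of a selector.
    delay: int
    noise_std: Optional[float]
    input_proto: Any


def channel_from_spec(spec: SpecSketch, example_state: Any) -> ChannelSketch:
    """Build a channel by evaluating the spec's selector on an example owner state."""
    return ChannelSketch(
        delay=spec.delay,
        noise_std=spec.noise_std,
        input_proto=spec.where(example_state),
    )


# A dict stands in for the owner's state in this sketch.
example_state = {"mechanics": {"pos": [0.0, 0.0]}, "net": [0.0] * 3}
spec = SpecSketch(where=lambda s: s["mechanics"]["pos"], delay=5)
channel = channel_from_spec(spec, example_state)
```

Because `where` is only meaningful relative to one particular state type, whatever performs this conversion has to know which state the lambda expects.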

This whole setup seems hacky and I'm trying to find a better way.

A first thought is that Channel could have a where field, like ChannelSpec, and no input_proto field. The user could then construct and pass a Channel, instead of a ChannelSpec, when constructing SimpleFeedback. During the execution of its stages, SimpleFeedback would pass its entire state to the Channel, whose where would pick out the substate internally.
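A minimal sketch of what this alternative might look like, in plain Python. `WhereChannelSketch` is a hypothetical stand-in rather than a proposed feedbax class, a tuple plays the role of the queue, and noise is left out.

```python
from dataclasses import dataclass
from typing import Any, Callable, Tuple


@dataclass(frozen=True)
class WhereChannelSketch:
    where: Callable[[Any], Any]
    delay: int = 0

    def __call__(self, full_state: Any, queue: Tuple[Any, ...]) -> Tuple[Any, Tuple[Any, ...]]:
        """Select the substate internally, enqueue it, and emit the oldest entry."""
        new_queue = queue + (self.where(full_state),)
        return new_queue[0], new_queue[1:]


channel = WhereChannelSketch(where=lambda s: s["pos"], delay=2)

# Placeholder queue: nothing here says what the queued entries should look like.
queue = (None,) * channel.delay

outputs = []
for t in range(4):
    # The channel receives the whole state and ignores everything but `pos`.
    output, queue = channel({"pos": t, "other": "ignored"}, queue)
    outputs.append(output)
```

In this sketch the first `delay` outputs are just the placeholder entries, since the channel has no prototype from which to build properly shaped initial values.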

However:

  • Passing the entire SimpleFeedbackState to Channel is at odds with the way where_state should be specified in AbstractStagedModel.model_spec: we'd have to do model surgery from within Channel, to replace the ChannelState component of SimpleFeedbackState, and return an entire SimpleFeedbackState. Then it would not be clear from SimpleFeedback.model_spec which substate of SimpleFeedbackState is actually relevant to Channel.
  • Similarly, we would have to define the generic type argument of Channel as SimpleFeedbackState to have the signature of its __call__ method agree with AbstractStagedModel.
  • Channel.init cannot provide a valid initial ChannelState, without access to an actual example of the subtree that's selected by where. In the current situation, the subtree example is provided by the Channel's self.input_proto. We do not want to pass input_proto to init as an argument, since all init methods are designed as providing a default state.
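The last point can be illustrated with a small sketch. `init_queue` is a hypothetical helper, not the real Channel.init, and lists of floats stand in for arrays: with a prototype, a correctly shaped queue of zeros can be built; without one, only placeholders are possible.

```python
from typing import Any, List, Optional


def init_queue(delay: int, input_proto: Optional[List[float]] = None) -> List[Any]:
    """Build an initial queue of `delay` entries shaped like `input_proto`, if known."""
    if input_proto is None:
        # Without an example of the selected subtree, the entry shape is unknown,
        # so the best this sketch can do is fill the queue with placeholders.
        return [None] * delay
    return [[0.0] * len(input_proto) for _ in range(delay)]


with_proto = init_queue(2, input_proto=[1.0, 2.0, 3.0])
without_proto = init_queue(2)
```

Here `with_proto` holds two zero-filled entries of length 3, while `without_proto` holds only `None` placeholders, which would not be a valid initial ChannelState.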