It's nice that we can intervene on a model's state PyTree, which in Feedbax is an equinox.Module object that is modified out-of-place by the model itself (an instance of AbstractStagedModel). However, sometimes we want to modify the parameters of the model object, such as the weights in a neural network. One such use case is online learning, where we want to update the parameters of the model within a single trial (between time steps).
Why an AbstractIntervenor can't intervene on model parameters
Intervenors added to an AbstractStagedModel cannot modify its fields. In JAX/Equinox we treat instances as immutable and update them out-of-place; this cannot be done from within the instance itself.
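For concreteness, here is a minimal sketch of what out-of-place updates look like in Equinox, using a toy module rather than a Feedbax class, and why they have to happen from outside the instance:

```python
import equinox as eqx
import jax.numpy as jnp

class ToyNetwork(eqx.Module):
    weights: jnp.ndarray

net = ToyNetwork(weights=jnp.ones((3, 3)))

# Equinox modules are frozen dataclasses, so in-place assignment fails:
# net.weights = net.weights * 0.5  # raises an error

# Instead, a *new* module is constructed out-of-place, from outside the instance:
net_updated = eqx.tree_at(lambda m: m.weights, net, replace=net.weights * 0.5)
```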
One option is to move the AbstractStagedModel fields we want to modify online into the model's respective AbstractState class, so that they "become" state variables. I'm not sure that's wise:
- Two different subclasses of AbstractStagedModel can depend on the same subclass of AbstractState. But they might not share parameters, so moving their parameters into AbstractState might lead to a proliferation of AbstractState subclasses.
- Model classes that refer to parameters stored all together in a field are less readable—we refer to self.params.x instead of self.x (see the sketch after this list).
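To illustrate the readability point, here is a hypothetical pair of modules (not Feedbax classes), one with its parameter collected in a separate params PyTree and one with the parameter stored directly as a field:

```python
import equinox as eqx
import jax.numpy as jnp

class GainParams(eqx.Module):
    gain: jnp.ndarray

class ChannelWithParamsField(eqx.Module):
    params: GainParams

    def __call__(self, x):
        return self.params.gain * x  # one extra level of indirection

class ChannelWithDirectField(eqx.Module):
    gain: jnp.ndarray

    def __call__(self, x):
        return self.gain * x  # parameters referred to directly
```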
On the other hand, keeping model parameters in a separate PyTree is the strategy used for subclasses of AbstractIntervenor, so their instances can be controlled on a trial-by-trial basis by model inputs provided by an AbstractTask. I've accepted the downsides in this case, because AbstractIntervenor subclasses tend to have small implementations and do not have "state" of their own, aside from their parameters.
Currently, Feedbax only allows offline updates to models: after a batch of trials is evaluated in TaskTrainer, we update the model by gradient descent with respect to the loss function. We also allow the user to specify additional functions (in the field model_update_funcs) which perform surgery on the model, given the evaluated states of the model for the current batch. Note that TaskTrainer has no access to the model/states on different time steps, only the state history following evaluation of the entire batch.
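As a rough sketch of that offline mechanism, such a function maps the model and the batch of evaluated states to a modified model. The exact signature expected by model_update_funcs is an assumption here, as is the field path into the model:

```python
import equinox as eqx
import jax.numpy as jnp

def clip_readout_weights(model, states):
    # Hypothetical field path; the real path depends on the model's structure.
    where = lambda m: m.step.net.readout.weight
    clipped = jnp.clip(where(model), -1.0, 1.0)
    # Out-of-place surgery on the model, as with any Equinox module.
    return eqx.tree_at(where, model, replace=clipped)
```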
Given the current class structure of Feedbax, the logical place to intervene on model parameters is from within AbstractIterator objects, which iterate models over time.
I think this should be relatively straightforward, though it will require at least one significant modification: AbstractIterator currently has a field _step to which an instance of AbstractModel—the model which is iterated—is assigned. This should not be a field. Instead, the step model should be passed as an argument to the __call__ method of the AbstractIterator subclass. This way, the model step can be altered out-of-place on each time step of the trial, then returned in its final form—potentially along with its entire history.
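A rough sketch of what that could look like, with assumed names: the online_update argument and the step(input_, state) call signature are placeholders, not Feedbax API.

```python
import equinox as eqx
import jax

def iterate(step, input_, init_state, n_steps, online_update):
    # Only the array leaves of the step module are carried through lax.scan;
    # the static parts are held fixed and recombined inside the loop.
    params, static = eqx.partition(step, eqx.is_array)

    def body(carry, _):
        params, state = carry
        step = eqx.combine(params, static)
        state = step(input_, state)        # one model step (assumed call signature)
        step = online_update(step, state)  # e.g. a local learning rule; must preserve structure
        params, _ = eqx.partition(step, eqx.is_array)
        return (params, state), state

    (params, final_state), history = jax.lax.scan(
        body, (params, init_state), None, length=n_steps
    )
    return eqx.combine(params, static), final_state, history
```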
I will probably try to implement this soon.
Note that this solution would only allow updates to parameters once per time step. This seems sufficient to me. In principle, turning model parameters into AbstractState (or similar) fields is a more general solution, since we could intervene on those fields at arbitrary points during execution of a single step of the model. However, I doubt that is necessary.
Even if AbstractState is eliminated (#24), this issue should be the same. We'd just be potentially adding model parameters to some other subclass of equinox.Module, instead of keeping them as fields of the model module.
This issue may look different if we replace the staged model approach with a DAG approach (#28), but unlike some other issues I'm not tagging this one as "won't fix" since this isn't a problem that will simply be obviated by the change in approach. We'll still have to consider how to treat dynamic parameters differently from model state.