Intervening on model parameters #21

Open
mlprt opened this issue Feb 26, 2024 · 2 comments
Labels
enhancement, structure

Comments

mlprt (Owner) commented Feb 26, 2024

It's nice that we can intervene on a model's state PyTree, which in Feedbax is an equinox.Module that the model (an instance of AbstractStagedModel) modifies out-of-place. However, sometimes we want to modify the parameters of the model object itself, such as the weights of a neural network. One such use case is online learning, where we want to update the model's parameters within a single trial, between time steps.

Why an AbstractIntervenor can't intervene on model parameters

Intervenors added to an AbstractStagedModel cannot modify its fields. In JAX/Equinox we treat instances as immutable and update them out-of-place; this cannot be done from within the instance itself.
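
For concreteness, here is a minimal sketch of the out-of-place update pattern in Equinox (the Network module is hypothetical). The point is that the update has to originate outside the module being updated:

```python
import equinox as eqx
import jax.numpy as jnp

class Network(eqx.Module):
    weights: jnp.ndarray  # a model parameter, stored as a field

    def __call__(self, x):
        # `self` is immutable here: eqx.Module is a frozen dataclass,
        # so assigning to self.weights inside this method would raise
        # an error rather than persist.
        return self.weights @ x

net = Network(weights=jnp.eye(3))

# The update has to come from outside, producing a new module:
net_updated = eqx.tree_at(lambda m: m.weights, net, 2.0 * net.weights)
```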

One option is to move the AbstractStagedModel fields we want to modify online into the model's associated AbstractState class, so that they "become" state variables. I'm not sure that's wise:

  1. Two different subclasses of AbstractStagedModel can depend on the same subclass of AbstractState. But they might not share parameters, so moving their parameters into AbstractState might lead to a proliferation of AbstractState subclasses.
  2. Model classes that refer to parameters stored all together in a field are less readable: we refer to self.params.x instead of self.x (see the sketch after this list).
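
To illustrate point 2 with hypothetical modules: grouping parameters into a params field adds an extra level of indirection at every use site.

```python
import equinox as eqx
import jax.numpy as jnp

class GainModelA(eqx.Module):
    gain: jnp.ndarray  # parameter as a direct field

    def __call__(self, x):
        return self.gain * x  # reads naturally as self.gain

class GainParams(eqx.Module):
    gain: jnp.ndarray

class GainModelB(eqx.Module):
    params: GainParams  # parameters grouped into a state-like PyTree

    def __call__(self, x):
        return self.params.gain * x  # indirection at every use site
```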

On the other hand, keeping model parameters in a separate PyTree is the strategy used for subclasses of AbstractIntervenor, so that their instances can be controlled on a trial-by-trial basis by model inputs provided by an AbstractTask. I've accepted the downsides in this case, because AbstractIntervenor subclasses tend to have small implementations and no "state" of their own, aside from their parameters.
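
A sketch of that pattern, with illustrative names and signatures rather than Feedbax's actual API: the intervenor's parameters live in a small PyTree of their own, which the task can substitute on each trial.

```python
import equinox as eqx
import jax.numpy as jnp

class CurlFieldParams(eqx.Module):
    amplitude: jnp.ndarray  # controllable trial-by-trial

class CurlField(eqx.Module):
    params: CurlFieldParams  # the intervenor's only "state"

    def __call__(self, velocity, params=None):
        # A task may supply substitute params for the current trial;
        # otherwise, fall back to the defaults stored on the module.
        p = self.params if params is None else params
        return p.amplitude * jnp.array([-velocity[1], velocity[0]])
```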


Currently, Feedbax only allows offline updates to models: after a batch of trials is evaluated in TaskTrainer, we update the model by gradient descent with respect to the loss function. We also allow the user to specify additional functions (in the field model_update_funcs) that perform surgery on the model, given the evaluated states of the model for the current batch. Note that TaskTrainer has no access to the model or its states at individual time steps, only to the state history following evaluation of the entire batch.
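
For example, a model update function might perform surgery along these lines. The signature and the field path model.step.net.weights are assumptions for illustration, not the actual Feedbax interface:

```python
import equinox as eqx
import jax.numpy as jnp

def clip_weights(model, states):
    # Hypothetical post-batch surgery: clip the network weights to a
    # fixed range, independently of the gradient descent update.
    return eqx.tree_at(
        lambda m: m.step.net.weights,
        model,
        jnp.clip(model.step.net.weights, -1.0, 1.0),
    )
```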

Given the current class structure of Feedbax, the logical place to intervene on model parameters is from within AbstractIterator objects, which iterate models over time.

I think this should be relatively straightforward, though it will require at least one significant modification: AbstractIterator currently has a field _step, to which the model to be iterated (an instance of AbstractModel) is assigned. This should not be a field. Instead, step should be passed as an argument to the __call__ method of each AbstractIterator subclass. That way, the model step can be altered out-of-place on each time step of the trial, then returned in its final form, potentially along with its entire history.
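
A rough sketch of what that refactor might look like using jax.lax.scan; the class name, call signatures, and the update_fn hook are all assumptions, not the actual Feedbax API:

```python
import equinox as eqx
import jax

class Iterator(eqx.Module):
    def __call__(self, step, state, inputs, update_fn=None):
        # `step` (the model) is an argument rather than a field, so it
        # can be replaced out-of-place on each time step and returned.
        # Partition into arrays (a valid scan carry) and static parts;
        # update_fn must preserve the PyTree structure of the arrays.
        arrays, static = eqx.partition(step, eqx.is_array)

        def scan_fn(carry, inputs_t):
            arrays, state = carry
            step = eqx.combine(arrays, static)
            state = step(inputs_t, state)
            if update_fn is not None:
                # Online learning: compute a new `step` from the
                # current state, e.g. via a local plasticity rule.
                step = update_fn(step, state)
            arrays, _ = eqx.partition(step, eqx.is_array)
            return (arrays, state), state

        (arrays, state), states = jax.lax.scan(
            scan_fn, (arrays, state), inputs
        )
        return eqx.combine(arrays, static), state, states
```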

I will probably try to implement this soon.

Note that this solution would only allow updates to parameters once per time step. This seems sufficient to me. In principle, turning model parameters into AbstractState (or similar) fields is a more general solution, since we could intervene on those fields at arbitrary points during execution of a single step of the model. However, I doubt that is necessary.

mlprt added the enhancement label Feb 26, 2024
mlprt (Owner) commented Feb 29, 2024

Even if AbstractState is eliminated (#24), this issue should be the same. We'd just be potentially adding model parameters to some other subclass of equinox.Module, instead of keeping them as fields of the model module.

mlprt (Owner) commented Jul 26, 2024

This issue may look different if we replace the staged model approach with a DAG approach (#28). Unlike some other issues, though, I'm not tagging this one as "won't fix", since it isn't a problem that will simply be obviated by the change in approach: we'll still have to consider how to treat dynamic parameters differently from model state.
