Include any neural network as a component of a staged model #2

Open
mlprt opened this issue Feb 15, 2024 · 1 comment
Labels
enhancement, structure

Comments


mlprt commented Feb 15, 2024

It should be straightforward to allow any Equinox-based neural network to be a component of an AbstractStagedModel, such as SimpleFeedback, and to be called during one of the model stages.

We would need to:

  1. If it doesn't take a state/hidden argument (e.g. it is not an RNN), wrap it so that it ignores the state argument that AbstractStagedModel.__call__ will try to pass it.
  2. Associate it with a generic NetworkState-like PyTree that has a single leaf, which stores the output of the module. Alternatively, keep a single Array in the PyTree of any model of which the network is a component (e.g. SimpleFeedback).
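The two steps above could look roughly like the following. This is a minimal sketch in plain Python, not Feedbax code: `OutputState` and `IgnoreState` are hypothetical names, frozen dataclasses stand in for Equinox modules, and the `(input, state)` call signature is an assumption about what a model stage passes to its component.

```python
from dataclasses import dataclass
from typing import Any, Callable


@dataclass(frozen=True)
class OutputState:
    """Hypothetical single-leaf state: holds only the module's latest output."""
    output: Any = None


@dataclass(frozen=True)
class IgnoreState:
    """Hypothetical wrapper that gives a stateless callable an (input, state) signature."""
    module: Callable[[Any], Any]

    def __call__(self, x: Any, state: OutputState) -> OutputState:
        # The incoming state is deliberately unused: a feedforward module
        # has no memory, so only its fresh output is stored.
        del state
        return OutputState(output=self.module(x))


def double(x):
    """Stand-in for a stateless third-party network."""
    return [2 * v for v in x]


wrapped = IgnoreState(double)
new_state = wrapped([1.0, 2.0], OutputState())
```

Here `new_state.output` holds the output of `double`, and whatever was in the previous state has no effect on the result. In an Equinox setting the wrapper and state would presumably be `eqx.Module` subclasses so that they remain PyTrees.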

On the other hand, we could specify the network itself as an AbstractStagedModel where network layers correspond to distinct stages, and where the activities of different layers may be kept as part of the state. In that case, the user can add interventions to these states without needing to redesign the model.

SimpleStagedNetwork is the prototype for a neural network AbstractStagedModel.
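To illustrate the layers-as-stages alternative, here is a minimal sketch in plain Python. It does not reproduce SimpleStagedNetwork or the AbstractStagedModel interface: `STAGES`, `run_stages`, and the `interventions` argument are hypothetical names, and lists of floats stand in for arrays.

```python
from typing import Callable, Dict, List, Optional, Tuple

Activity = List[float]
Stage = Tuple[str, Callable[[Activity], Activity]]


def scale(factor: float) -> Callable[[Activity], Activity]:
    """Toy layer: multiply every unit's activity by a constant."""
    return lambda xs: [factor * x for x in xs]


def relu(xs: Activity) -> Activity:
    return [max(0.0, x) for x in xs]


# Each network layer is a named stage (hypothetical stage names).
STAGES: List[Stage] = [
    ("hidden", scale(2.0)),
    ("nonlinearity", relu),
    ("readout", scale(0.5)),
]


def run_stages(
    x: Activity,
    stages: List[Stage],
    interventions: Optional[Dict[str, Callable[[Activity], Activity]]] = None,
) -> Dict[str, Activity]:
    """Run the stages in order, recording each layer's activity in the state."""
    interventions = interventions or {}
    state: Dict[str, Activity] = {}
    activity = x
    for name, layer in stages:
        activity = layer(activity)
        # An intervention edits a layer's activity before later stages see it.
        if name in interventions:
            activity = interventions[name](activity)
        state[name] = activity
    return state


baseline = run_stages([-1.0, 3.0], STAGES)
silenced = run_stages(
    [-1.0, 3.0], STAGES, interventions={"hidden": lambda a: [0.0] * len(a)}
)
```

Because every layer's activity is a named entry in the state, the intervention on `"hidden"` is added at call time without changing the stage definitions, which is the property the comment above describes.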

mlprt added the help wanted and staged models labels and removed the help wanted label on Feb 15, 2024
mlprt added the structure and enhancement labels and removed the staged models label on Feb 28, 2024
mlprt changed the title from "Neural networks as staged models" to "Include any neural network as a component of a staged model" on Feb 28, 2024

mlprt commented Jul 26, 2024

Associating third-party modules with states would probably be simpler if we replaced the staged model approach with a DAG approach (#28).
