module.init_states from within jit #591

Open
chaseking opened this issue Feb 21, 2025 · 3 comments

@chaseking

chaseking commented Feb 21, 2025

I'm fitting some channel parameters that alter the return value of channel.init_state. It would be nice to have an alternative to module.init_states (data_init_states?) that returns the states as a list[dict], which can then be passed to the param_state keyword argument of jx.integrate. (If there's a better way to do this currently, let me know!)
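
To illustrate why the initial state has to be recomputed whenever such a parameter changes, here is a minimal sketch in plain Python. The gate `toy_m`, the parameters `toy_v_half` and `toy_slope`, and the function `toy_init_state` are all hypothetical and not part of jaxley; only the argument order (states, voltages, params, delta_t) follows the `channel.init_state` call used in this thread.

```python
import math


def toy_init_state(states, v, params, delta_t):
    """Steady state of a hypothetical gate `toy_m` at voltages `v` (in mV)."""
    v_half = params["toy_v_half"]  # half-activation voltage, hypothetical
    slope = params["toy_slope"]    # slope factor, hypothetical
    # Boltzmann steady-state curve, evaluated per compartment.
    m_inf = [
        1.0 / (1.0 + math.exp(-(vi - vh) / k))
        for vi, vh, k in zip(v, v_half, slope)
    ]
    return {"toy_m": m_inf}


voltages = [-70.0, -70.0]
default = toy_init_state(
    {}, voltages, {"toy_v_half": [-40.0, -40.0], "toy_slope": [5.0, 5.0]}, 0.025
)
shifted = toy_init_state(
    {}, voltages, {"toy_v_half": [-70.0, -70.0], "toy_slope": [5.0, 5.0]}, 0.025
)
```

With the same resting voltage, moving `toy_v_half` from -40 mV to -70 mV changes the steady state of the gate from close to 0 to exactly 0.5, so a state initialized once before training would no longer match the fitted parameters.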

@chaseking
Author

This is my current workaround, based on the source of init_states:

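from typing import Optional

import jax.numpy as jnp
import jaxley as jx
# Import path assumed from jaxley's `integrate` module; it may differ between versions.
from jaxley.utils.cell_utils import params_to_pstate


def data_init_states(
    module: jx.Module,
    delta_t: float = 0.025,
    params: Optional[list[dict[str, jnp.ndarray]]] = None,
    param_state: Optional[list[dict]] = None,
) -> Optional[list[dict]]:
    """Returns parameter states for channels such that they are initialized
    in their steady state.

    This considers the voltages and parameters of each compartment. Entries
    already present in `param_state` are not used when computing the steady state.

    Args:
        module: The module whose channel states should be initialized.
        delta_t: Passed on to `channel.init_state()`.
        params: Params of model, optional.
        param_state: An optional list of dictionaries as returned by `data_set`.
            New entries are appended to a copy of this list.

    Returns:
        A list of dictionaries that can be passed as `param_state` to `jx.integrate`.
    """
    nodes = module.base.nodes
    states = module.base._get_states_from_nodes_and_edges()
    # Apply the trainable parameters (if given) before building the parameter arrays.
    pstate = [] if params is None else params_to_pstate(params, module.base.indices_set_by_trainables)
    params = module.base.get_all_parameters(pstate, voltage_solver="jaxley.thomas")

    for channel in module.base.channels:
        name = channel._name
        nodes_with_channel = nodes.loc[nodes[name]]  # nodes[name] is a bool flag
        channel_indices = nodes_with_channel["global_comp_index"].to_numpy()
        voltages = nodes_with_channel["v"].to_numpy()
        # Restrict states and parameters to the compartments that contain this channel.
        channel_states = {
            key: states[key][channel_indices]
            for key in channel.channel_states.keys()
        }
        channel_params = {
            key: params[key][channel_indices]
            for key in channel.channel_params.keys()
        }
        init_state = channel.init_state(channel_states, voltages, channel_params, delta_t)
        # `init_state` may return only a subset of the channel states.
        for key, val in init_state.items():
            added_param_state = {
                "key": key,
                "indices": jnp.asarray(channel_indices).reshape(-1, 1),  # (num_params, num_comps_per_param)
                "val": jnp.atleast_1d(jnp.asarray(val)),
            }
            # Build a new list instead of mutating the caller's `param_state`.
            if param_state is None:
                param_state = [added_param_state]
            else:
                param_state = param_state + [added_param_state]

    return param_state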
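
The layout of the returned entries can be sketched in plain Python. This is only an illustration of the `key` / `indices` / `val` structure built above, under the assumption (suggested by the shape comment in the workaround) that row `i` of `indices` lists the compartments that receive `val[i]`. The helper `apply_param_state` and the state name `toy_m` are hypothetical and do not mirror jaxley's actual implementation.

```python
def apply_param_state(values, param_state):
    """Return a copy of `values` with every param_state entry written into it."""
    updated = {key: list(vals) for key, vals in values.items()}
    for entry in param_state:
        key, indices, val = entry["key"], entry["indices"], entry["val"]
        for row, comp_indices in enumerate(indices):
            # A length-1 `val` is shared by all rows (broadcast).
            new_value = val[row] if len(val) > 1 else val[0]
            for comp in comp_indices:
                updated[key][comp] = new_value
    return updated


# Three compartments; the hypothetical channel lives in compartments 0 and 2.
values = {"toy_m": [0.0, 0.0, 0.0]}
param_state = [{"key": "toy_m", "indices": [[0], [2]], "val": [0.2, 0.7]}]
result = apply_param_state(values, param_state)
```

Here compartment 1 keeps its previous value because the channel is absent there, which is why the workaround stores one index row per compartment rather than a dense array.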

@michaeldeistler
Contributor

Hi there,

Yes, absolutely agreed! We will tackle this for v1.0. Thanks for raising it!

Michael

@jnsbck
Contributor

jnsbck commented Feb 24, 2025

Also being discussed in #508.
