module.init_states from within jit #591
This is my current workaround, based off the source of `module.init_states`:

```python
from typing import Optional

import jax.numpy as jnp
import jaxley as jx
# Assumed import location of this helper; it may differ between jaxley versions.
from jaxley.utils.cell_utils import params_to_pstate


def data_init_states(
    module: jx.Module,
    delta_t: float = 0.025,
    params: Optional[list[dict[str, jnp.ndarray]]] = None,
    param_state: Optional[list[dict]] = None,
) -> list[dict]:
    """Returns parameter states for channels such that they are initialized
    in their steady state.

    This considers the voltages and parameters of each compartment.

    Args:
        module: The module whose channel states should be initialized.
        delta_t: Passed on to `channel.init_state()`.
        params: Params of model, optional.
        param_state: An optional list of dictionaries as returned by `data_set`.

    Returns:
        `param_state` extended by one entry per channel state, suitable for
        the `param_state` argument of `jx.integrate`.
    """
    nodes = module.base.nodes
    states = module.base._get_states_from_nodes_and_edges()
    # Turn the (possibly traced) trainable params into a pstate so that the
    # init states depend on them inside `jit`/`grad`.
    pstate = (
        []
        if params is None
        else params_to_pstate(params, module.base.indices_set_by_trainables)
    )
    all_params = module.base.get_all_parameters(pstate, voltage_solver="jaxley.thomas")

    # Copy so that the caller's list is never mutated.
    param_state = [] if param_state is None else list(param_state)
    for channel in module.base.channels:
        name = channel._name
        nodes_with_channel = nodes.loc[nodes[name]]  # nodes[name] is a bool flag
        channel_indices = nodes_with_channel["global_comp_index"].to_numpy()
        voltages = nodes_with_channel["v"].to_numpy()

        channel_states = {
            key: states[key][channel_indices] for key in channel.channel_states
        }
        channel_params = {
            key: all_params[key][channel_indices] for key in channel.channel_params
        }
        init_state = channel.init_state(
            channel_states, voltages, channel_params, delta_t
        )
        for key, val in init_state.items():
            param_state.append(
                {
                    "key": key,
                    # (num_params, num_comps_per_param): one entry per compartment.
                    "indices": jnp.asarray(channel_indices).reshape(-1, 1),
                    "val": jnp.atleast_1d(jnp.asarray(val)),
                }
            )
    return param_state
```
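The `"indices"` and `"val"` entries in the workaround follow the shape convention noted in its comment, `(num_params, num_comps_per_param)`. The pure-JAX sketch that follows illustrates that convention by scattering one such entry into a flat per-compartment array. It assumes a scatter of the form `state.at[indices].set(val[:, None])`; this is not taken from jaxley's source, and `apply_entry` and the state name `HH_m` are hypothetical.

```python
import jax.numpy as jnp


def apply_entry(state: jnp.ndarray, entry: dict) -> jnp.ndarray:
    """Scatter one param_state-style entry into a flat per-compartment array.

    Hypothetical helper: assumes `indices` has shape
    (num_params, num_comps_per_param) and `val` has shape (num_params,).
    """
    indices = entry["indices"]
    val = entry["val"]
    # val[:, None] broadcasts each value across its group of compartments.
    return state.at[indices].set(val[:, None])


# Five compartments; compartments 0, 2 and 4 carry the channel.
state = jnp.zeros(5)
channel_indices = jnp.array([0, 2, 4])

# One group per compartment, as produced by reshape(-1, 1).
entry = {
    "key": "HH_m",  # hypothetical state name
    "indices": channel_indices.reshape(-1, 1),  # shape (3, 1)
    "val": jnp.atleast_1d(jnp.array([0.1, 0.2, 0.3])),  # shape (3,)
}
new_state = apply_entry(state, entry)

# A single group sharing one scalar value across compartments 1 and 3.
shared = {
    "key": "HH_m",
    "indices": jnp.array([[1, 3]]),  # shape (1, 2)
    "val": jnp.atleast_1d(jnp.asarray(0.5)),  # scalar lifted to shape (1,)
}
shared_state = apply_entry(state, shared)
```

With `reshape(-1, 1)` every compartment forms its own group and receives its own value; a single row such as `[[1, 3]]` instead shares one value across both compartments, which is why a scalar `val` is lifted to shape `(1,)` with `jnp.atleast_1d`.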
Hi there, yes, absolutely agreed! We will tackle this for Michael
Also being discussed in #508.
I'm fitting some channel parameters that alter the return value of `channel.init_state`. It would be nice to have an alternative to `module.init_states` (`data_init_states`?) that returns a state `list[dict]` which can then be passed to the `param_state` keyword argument of `jx.integrate`. (If there's a better way to do this currently, lmk!)
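The motivation above, that fitted parameters change what `channel.init_state` returns, can be illustrated without jaxley. The sketch assumes a hypothetical gate `toy_m` with a Boltzmann steady state whose midpoint `v_half` is the trainable parameter, and uses a toy squared-error `loss` as a stand-in for a simulation. Because the initial gate values are recomputed inside the `jax.jit`-compiled loss, the gradient with respect to `v_half` includes their dependence on it; a `param_state` computed once outside the loss would not.

```python
import jax
import jax.numpy as jnp


def steady_state_m(v, v_half, slope=5.0):
    # Hypothetical Boltzmann steady state of a single gate.
    return 1.0 / (1.0 + jnp.exp(-(v - v_half) / slope))


def toy_init_param_state(v_half, voltages, comp_indices):
    # Mirrors the structure of data_init_states for one hypothetical gate.
    m0 = steady_state_m(voltages, v_half)
    return [
        {
            "key": "toy_m",
            "indices": comp_indices.reshape(-1, 1),
            "val": jnp.atleast_1d(m0),
        }
    ]


@jax.jit
def loss(v_half, voltages, comp_indices, target):
    # Init states are rebuilt from the traced parameter on every call.
    param_state = toy_init_param_state(v_half, voltages, comp_indices)
    # Stand-in for a simulation: distance of initial gate values from a target.
    return jnp.sum((param_state[0]["val"] - target) ** 2)


voltages = jnp.array([-70.0, -65.0, -60.0])
comp_indices = jnp.array([0, 1, 2])
target = jnp.array([0.1, 0.2, 0.3])

grad_v_half = jax.grad(loss)(-60.0, voltages, comp_indices, target)
```

Strings are not valid JAX types, so a list carrying `"key"` strings cannot be the return value of a `jax.jit`-compiled function; consuming it inside the compiled function, as `loss` does, avoids that.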