The suitability of `WhereDict`: lambdas as keys #14
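Given a

```python
import equinox as eqx
import jax.tree_util as jtu


class _NodeWrapper:
    def __init__(self, value):
        self.value = value


def where_func_to_paths(where, tree):
    # Wrap each node selected by `where` so that it is treated as a single, distinct leaf.
    tree = eqx.tree_at(where, tree, replace_fn=lambda x: _NodeWrapper(x))
    # Replace every leaf (wrappers included) by its `id`, then re-select the wrapped ones.
    id_tree = jtu.tree_map(id, tree, is_leaf=lambda x: isinstance(x, _NodeWrapper))
    node_ids = where(id_tree)
    # `where` may return a single node, so collect the ids into a set before testing membership.
    selected_ids = set(jtu.tree_leaves(node_ids))
    paths_by_id = {leaf_id: path for path, leaf_id in jtu.tree_leaves_with_path(
        jtu.tree_map(lambda x: x if x in selected_ids else None, id_tree)
    )}
    # Map each selected id back to its key path, in the same structure that `where` returns.
    paths = jtu.tree_map(lambda node_id: paths_by_id[node_id], node_ids)
    return paths
```

Profiling a couple of typical cases of

Note the use of

Limitations: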
Advantages:
A third option is to use the

```python
from typing import Any, Callable

import jax
from jaxtyping import PyTree


class WhereStrConstructor:
    def __init__(self, label: str = ""):
        self.label = label

    def __getitem__(self, key: Any):
        if isinstance(key, str):
            key = f"'{key}'"
        elif isinstance(key, type):
            key = key.__name__
        # Add other conditional representations, as needed.
        return WhereStrConstructor("".join([self.label, f"[{key}]"]))

    def __getattr__(self, name: str):
        # Don't fabricate special attributes (e.g. `__deepcopy__`), which would fool introspection.
        if name.startswith("__"):
            raise AttributeError(name)
        sep = "." if self.label else ""
        return WhereStrConstructor(sep.join([self.label, name]))


def where_func_to_labels(where: Callable) -> PyTree[str]:
    return jax.tree_util.tree_map(lambda x: x.label, where(WhereStrConstructor()))
```

Advantages:
Disadvantages:
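The proxy in this third option could also stand in for the bytecode parsing that `WhereDict` relies on: instead of disassembling the lambda, call it once on a recording object and use the resulting label as an ordinary dictionary key. Below is a minimal pure-Python sketch under that assumption; `_LabelRecorder` and `where_label` are hypothetical names, the sketch handles only a single selected node (no `tree_map`), and it makes no claim about speed relative to `dis`.

```python
from typing import Any, Callable


class _LabelRecorder:
    """Hypothetical, stripped-down variant of `WhereStrConstructor` above."""

    def __init__(self, label: str = ""):
        self._label = label

    def __getattr__(self, name: str) -> "_LabelRecorder":
        # Only called for attributes not found normally; refuse dunders so that
        # introspection (copying, pickling, ...) is not fooled.
        if name.startswith("__"):
            raise AttributeError(name)
        sep = "." if self._label else ""
        return _LabelRecorder(f"{self._label}{sep}{name}")

    def __getitem__(self, key: Any) -> "_LabelRecorder":
        return _LabelRecorder(f"{self._label}[{key!r}]")


def where_label(where: Callable) -> str:
    """Return a string label for a `where` function that selects a single node."""
    result = where(_LabelRecorder())
    if not isinstance(result, _LabelRecorder):
        raise TypeError("`where` must return a single attribute/item access")
    return result._label


# Two separately defined but equivalent lambdas map to the same string key.
init = {where_label(lambda state: state.mechanics.effector): "init data"}
print(init[where_label(lambda s: s.mechanics.effector)])  # init data
```

Unlike reading `LOAD_ATTR` instructions, this also records item access (`lambda s: s.net["hidden"]` becomes `"net['hidden']"`), but it will typically raise an error if the lambda does anything besides attribute and item access, such as calling a method or doing arithmetic.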
`AbstractTask` provides data with which to initialize subsets of the model state at the start of each task trial. This could be accomplished by providing a PyTree with the same structure as the full model state, with `None` at all leaves except those to be initialized, and then using `eqx.combine` to replace the model substates with the initial values provided. However, this would require that each type of `AbstractTask` be associated with a particular type of state PyTree, whereas in principle a task type should be compatible with any model whose state contains at least the substates that 1) are to be initialized and 2) are targets of, or otherwise part of, the loss computation. For example, when defining a task in terms of initial and target states for the biomechanical effector, it shouldn't matter how complex the PyTree of states for the neural network is.

Therefore, in `AbstractTask` the initial values for substates are specified (example) as a pairing of a lambda that selects the substate to be initialized from the full model state, with a PyTree of data with the same structure as that substate. Then `TaskTrainer` performs a series of `equinox.tree_at` surgeries based on this mapping (`feedbax/feedbax/train.py`, lines 516 to 521 in 2ce8b1c).
As far as `TaskTrainer` is concerned, it would be sufficient to provide these pairings as `tuple[Callable, PyTree]`. However, in general it seems to make sense for the pairing to be a mapping: first, because there should be at most a single initialization for each substate; but also because if the user (or a function in `feedbax.plot`) wants to access the initialization data from the trial specification, it is more convenient to write (say) `trial_spec.init['mechanics.effector']` or even `trial_spec.init[lambda state: state.mechanics.effector]` than to have to figure out which `tuple[Callable, PyTree]` contains the `Callable` that refers to the part of the state they're interested in.

Unfortunately, lambdas cannot simply be used as keys in a mapping. If I define a dict with a key `lambda state: state.something`, and later try to get the value associated with a newly defined key `lambda s: s.something`, I'll encounter a `KeyError`, because lambdas are not hashed according to the function they represent, but by object identity (i.e. their memory address).

Thus we use `WhereDict`, which is an `OrderedDict` that enables limited use of lambdas as keys. In particular, it uses `dis` to parse the `LOAD_ATTR` operations in the lambda's bytecode, and constructs an equivalent string representation. For example, `lambda x: x.foo.bar` is parsed as `"foo.bar"`. This only works when the lambda takes a single argument and returns a single (nested) attribute access on the argument. Both a lambda and its string representation can be used as a `WhereDict` key.

The downside is the overhead of `dis.Bytecode`, such that `WhereDict` is about 100x slower to construct, and 500-1000x slower to access, than `OrderedDict`. In practice this is not a big deal for our use case, as we only need to do a single construction and a single access on each training batch, leading to an overhead of about 125 µs, whereas a batch normally takes at least 20,000 µs (an overhead of roughly 0.6%). Also, it's unlikely the user will initialize more than a few substates separately, i.e. there will be no more than a few entries per `WhereDict`.

Is there a better or faster way to do this?
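For reference, the `KeyError` and the bytecode-parsing workaround described above can be reproduced in a few lines. This is a minimal, hypothetical sketch (`where_to_str` and `WhereDictSketch` are illustrative names, not feedbax's `WhereDict`), assuming CPython, where `dis.Bytecode` yields `Instruction` objects whose `argval` holds the attribute name for `LOAD_ATTR`; opcode details otherwise differ between Python versions.

```python
import dis
from collections import OrderedDict
from typing import Any, Callable


def where_to_str(where: Callable) -> str:
    """Convert e.g. `lambda x: x.foo.bar` to "foo.bar" by reading its bytecode.

    Only LOAD_ATTR instructions are inspected, so any other operations in the
    lambda (subscripts, calls, ...) are silently ignored in this sketch.
    """
    if where.__code__.co_argcount != 1:
        raise ValueError("only single-argument lambdas are supported")
    attrs = [
        instr.argval
        for instr in dis.Bytecode(where)
        if instr.opname == "LOAD_ATTR"
    ]
    if not attrs:
        raise ValueError("expected at least one attribute access")
    return ".".join(attrs)


class WhereDictSketch(OrderedDict):
    """OrderedDict whose keys may be given as attribute-chain lambdas or as strings."""

    @staticmethod
    def _normalize(key: Any) -> Any:
        # Lambdas are converted to strings; strings (not callable) pass through unchanged.
        return where_to_str(key) if callable(key) else key

    def __init__(self, items=()):
        super().__init__()
        for key, value in items:
            self[key] = value

    def __setitem__(self, key, value):
        super().__setitem__(self._normalize(key), value)

    def __getitem__(self, key):
        return super().__getitem__(self._normalize(key))

    def __contains__(self, key):
        return super().__contains__(self._normalize(key))


# A plain dict hashes lambdas by identity, so an equivalent new lambda misses.
plain = {lambda state: state.mechanics.effector: "init data"}
try:
    plain[lambda s: s.mechanics.effector]
except KeyError:
    print("plain dict: KeyError")

spec = WhereDictSketch([(lambda state: state.mechanics.effector, "init data")])
print(spec[lambda s: s.mechanics.effector])  # init data
print(spec["mechanics.effector"])  # init data
```

In this sketch every set and get re-parses the key's bytecode, so the `dis` cost is paid on each access as well as on construction.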
I've considered that it might work to specify the state initialization as a prefix of the model state, rather than using a lambda combined with a substate. However, this does not solve the access problem: we'd still need to map to the prefixes with string keys like `"mechanics.effector"` so that the user could refer to them easily. But 1) those keys would no longer have a lawful relationship to the surgeries performed to assign the initial states, and 2) the user would still have to access the appropriate leaves from the prefix tree.