The suitability of WhereDict: lambdas as keys #14

Open
mlprt opened this issue Feb 21, 2024 · 2 comments
Labels
question Further information is requested

Comments

Owner

mlprt commented Feb 21, 2024

AbstractTask provides data with which to initialize subsets of the model state, at the start of each task trial.

This could be accomplished by providing a PyTree with the same structure as the full model state, with None at all leaves except those to be initialized, and then using eqx.combine to replace the model substates with initial values provided. However, this would require that each type of AbstractTask be associated with a particular type of state PyTree, whereas in principle a task type should be compatible with any model whose associated state contains at least the substates 1) to be initialized and 2) that are targets/part of the loss computation. For example, it shouldn't matter how complex the PyTree of states for the neural network is, when defining a task in terms of initial and target states for the biomechanical effector.

Therefore, in AbstractTask the initial values for substates are specified (example) as a pairing of a lambda that selects the substate to be initialized from the full model state, with a PyTree of data with the same structure as that substate. Then TaskTrainer performs a series of equinox.tree_at surgeries based on this mapping:

feedbax/feedbax/train.py

Lines 516 to 521 in 2ce8b1c

for where_substate, init_substates in trial_specs.inits.items():
    init_states = eqx.tree_at(
        where_substate,
        init_states,
        init_substates,
    )

As far as TaskTrainer is concerned, it would be sufficient to provide these pairings as tuple[Callable, PyTree]. However, in general it seems to make sense for the pairing to be a mapping -- first, because there should be at most a single initialization provided for each substate. But also, if the user -- or a function in feedbax.plot -- wants to access the initialization data from the trial specification, it is more convenient to write (say) trial_spec.inits['mechanics.effector'] or even trial_spec.inits[lambda state: state.mechanics.effector] than to have to figure out which tuple[Callable, PyTree] contains the Callable that refers to the part of the state they're interested in.

Unfortunately, lambdas cannot simply be used as keys in a mapping. If I define a dict with a key lambda state: state.something, and later try to get the value associated with a newly-defined key lambda s: s.something, I'll encounter a KeyError because lambdas are not hashed according to the function they represent, but by their memory address.
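The failure is easy to reproduce with the standard library alone. In this small illustration, `inits` is just a plain `dict`:

```python
# Two lambdas with identical bodies are still distinct function objects,
# and functions hash and compare by identity, not by what they compute.
f = lambda state: state.something
g = lambda s: s.something

inits = {f: "initial value"}

print(inits[f])   # the very same object works as a key
print(f == g)     # False

try:
    inits[g]
except KeyError:
    print("KeyError: an equivalent but newly defined lambda is a different key")
```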

Thus we use WhereDict, which is an OrderedDict that enables limited use of lambdas as keys. In particular, it uses dis to parse the LOAD_ATTR operations in the lambda's bytecode, and constructs an equivalent string representation. For example, lambda x: x.foo.bar is parsed as "foo.bar". This only works when the lambda takes a single argument, and returns a single (nested) attribute access on the argument.

Both a lambda and its string representation can be used as a WhereDict key:
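The parsing step can be sketched roughly as follows. `where_to_str` is a hypothetical helper written for this illustration, not the actual feedbax implementation, and it only collects `LOAD_ATTR` names without rejecting other operations.

```python
import dis
import inspect
from typing import Callable


def where_to_str(where: Callable) -> str:
    """Return e.g. "foo.bar" for `lambda x: x.foo.bar` (hypothetical helper)."""
    if len(inspect.signature(where).parameters) != 1:
        raise ValueError("where function must take exactly one argument")
    # Each attribute access compiles to a LOAD_ATTR instruction whose
    # argval is the attribute name.
    attrs = [
        instr.argval
        for instr in dis.Bytecode(where)
        if instr.opname == "LOAD_ATTR"
    ]
    if not attrs:
        raise ValueError("where function must access at least one attribute")
    return ".".join(attrs)


print(where_to_str(lambda x: x.foo.bar))          # foo.bar
print(where_to_str(lambda state: state.foo.bar))  # foo.bar
```

Because only `LOAD_ATTR` is inspected, this sketch would silently map `lambda x: x.foo[0].bar` to "foo.bar" as well; a stricter version would need to validate the remaining instructions.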

assert my_where_dict['foo.bar'] is my_where_dict[lambda x: x.foo.bar]
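One way such a mapping could be built is to normalize every key to its string form before it reaches the underlying dict. `LambdaKeyDict` and `_normalize` below are hypothetical names for this sketch; the real WhereDict may store or validate keys differently.

```python
import dis
from collections import OrderedDict


def _normalize(key):
    """Map a where-lambda to its dotted attribute string; pass strings through."""
    if callable(key):
        return ".".join(
            instr.argval
            for instr in dis.Bytecode(key)
            if instr.opname == "LOAD_ATTR"
        )
    return key


class LambdaKeyDict(OrderedDict):
    """Hypothetical stand-in for WhereDict: keys are stored as strings."""

    def __setitem__(self, key, value):
        super().__setitem__(_normalize(key), value)

    def __getitem__(self, key):
        return super().__getitem__(_normalize(key))

    def __contains__(self, key):
        return super().__contains__(_normalize(key))


d = LambdaKeyDict()
d[lambda state: state.foo.bar] = [1.0, 2.0]

# A newly defined lambda and the string form reach the same entry.
assert d["foo.bar"] is d[lambda x: x.foo.bar]
print(list(d.keys()))  # ['foo.bar']
```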

The downside is the overhead of dis.Bytecode, such that WhereDict is about 100x slower to construct, and 500-1000x slower to access, than OrderedDict. In practice this is not a big deal for our use case, as we only need to do a single construct and a single access on each training batch, leading to an overhead of about 125 us, where a batch normally takes at least 20,000 us. Also, it's unlikely the user will initialize more than a few substates separately, i.e. no more than a few entries per WhereDict.
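The relative cost can be checked on any machine with a rough timing sketch like the one below. It compares a plain string lookup against a lookup that first parses a lambda with `dis`; the helper `where_to_str` is hypothetical, and the numbers will differ from those quoted above depending on hardware and Python version.

```python
import dis
import timeit
from collections import OrderedDict


def where_to_str(where):
    # Hypothetical helper: join the attribute names loaded by the lambda.
    return ".".join(
        i.argval for i in dis.Bytecode(where) if i.opname == "LOAD_ATTR"
    )


where = lambda s: s.mechanics.effector
plain = OrderedDict({"mechanics.effector": 0})

n = 2_000
t_plain = timeit.timeit(lambda: plain["mechanics.effector"], number=n) / n
t_parsed = timeit.timeit(lambda: plain[where_to_str(where)], number=n) / n

print(f"plain lookup:        {t_plain * 1e6:.3f} us")
print(f"lookup via dis parse: {t_parsed * 1e6:.3f} us")
```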

Is there a better or faster way to do this?

I've considered that it might work to specify the state initialization as a prefix of the model state rather than using a lambda combined with a substate. However, this does not solve the access problem -- we'd still need to map to the prefixes with string keys like "mechanics.effector" so that the user could refer to them easily; however 1) those keys would no longer have a lawful relationship to the surgeries performed to assign the initial states, and 2) the user would still have to access the appropriate leaves from the prefix tree.

Owner Author

mlprt commented Apr 17, 2024

Given a where function and a tree, we can obtain the paths of the nodes in the tree that are returned by the function using something like this:

import equinox as eqx
import jax.tree_util as jtu

class _NodeWrapper:
    def __init__(self, value):
        self.value = value

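def where_func_to_paths(where, tree):
    # Wrap each selected node so it is treated as a leaf, whatever it contains.
    tree = eqx.tree_at(where, tree, replace_fn=lambda x: _NodeWrapper(x))
    # Replace every leaf (including the wrappers) with its object ID.
    id_tree = jtu.tree_map(id, tree, is_leaf=lambda x: isinstance(x, _NodeWrapper))
    node_ids = where(id_tree)
    # Flatten, so the membership test also works when `where` returns a single node.
    selected_ids = set(jtu.tree_leaves(node_ids))

    # Non-selected leaves become None, which is dropped when flattening.
    paths_by_id = {leaf_id: path for path, leaf_id in jtu.tree_leaves_with_path(
        jtu.tree_map(lambda x: x if x in selected_ids else None, id_tree)
    )}

    paths = jtu.tree_map(lambda node_id: paths_by_id[node_id], node_ids)

    return paths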

Profiling a couple of typical cases of where functions with a SimpleFeedbackState tree, this solution seems to be 2-3x slower than the dis solution. However, as mentioned in the preceding comment, we call this function infrequently so that's probably not an issue.

Note the use of _NodeWrapper so that we can identify paths of nodes, and not just leaves. However, the tree_at call is the slowest part of the function. Perhaps there is some way to use is_leaf in the first tree_map, to avoid the call to tree_at.

Example profile: [image not included]

Limitations:

  • Unlike the dis solution, requires access to a pytree with the appropriate nodes.
  • Assumes that the same object does not appear as different nodes in the tree (i.e. different paths associated with the same object ID), though if that's a concern, we could replace those nodes with unique objects.
  • The where function can't specify both a node, and another node contained within that node, since the outer node will get wrapped in _NodeWrapper which will mask the inner one. This should be solvable.

Advantages:

  • where could return a PyTree of nodes, and this function will give the respective PyTree of paths. On the other hand, the dis solution only works (so far) for a where that selects a single node -- extending it to work for arbitrary PyTrees would require some more complex parsing.
  • Should work for nodes that are accessed by dict keys, sequence indices, etc., whereas the current dis solution only works for attribute access.

Owner Author

mlprt commented Apr 17, 2024

A third option is to use the where function to construct a string representation directly, like so:

from collections.abc import Callable
from typing import Any

import jax
from jaxtyping import PyTree

class WhereStrConstructor:

    def __init__(self, label: str = ""):
        self.label = label

    def __getitem__(self, key: Any):
        if isinstance(key, str):
            key = f"'{key}'"
        elif isinstance(key, type):
            key = key.__name__
        # Add other conditional representations, as needed.
        return WhereStrConstructor("".join([self.label, f"[{key}]"]))

    def __getattr__(self, name: str):
        # Only called for attributes not found normally, i.e. anything but `label`.
        sep = "." if self.label else ""
        return WhereStrConstructor(sep.join([self.label, name]))


def where_func_to_labels(where: Callable) -> PyTree[str]:
    # jax.tree_map is deprecated; use the equivalent function in jax.tree_util.
    return jax.tree_util.tree_map(lambda x: x.label, where(WhereStrConstructor()))

Advantages:

  • Doesn't need a PyTree; this is really just a way of getting a string representation of a function that returns a PyTree of nodes from a tree;
  • Runs significantly faster than the other solutions; only about 10 us for a typical case;
  • No complications due to tree traversal or bytecode parsing.

Disadvantages:

  • Doesn't give the paths of the selected nodes -- though in most cases this shouldn't be a problem since we already have access to the where function.
  • Is totally oblivious to the types and actual structure of the PyTrees that the where is relevant to. This shouldn't be an issue since any problems with the actual structure of the tree, would anyway lead to errors being raised elsewhere.
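To illustrate what this third option produces, here is a dependency-free variant of the same idea. `LabelRecorder` and `where_to_labels` are hypothetical names for this sketch; it uses `repr` for item keys and handles only a single node or a flat tuple of nodes, rather than mapping over arbitrary PyTrees with JAX.

```python
from typing import Any, Callable


class LabelRecorder:
    """Hypothetical recorder: builds a label from the accesses applied to it."""

    def __init__(self, label: str = ""):
        self.label = label

    def __getattr__(self, name: str) -> "LabelRecorder":
        # Only reached for attributes not found normally, i.e. not `label`.
        sep = "." if self.label else ""
        return LabelRecorder(f"{self.label}{sep}{name}")

    def __getitem__(self, key: Any) -> "LabelRecorder":
        # repr() quotes string keys and leaves integer indices bare.
        return LabelRecorder(f"{self.label}[{key!r}]")


def where_to_labels(where: Callable):
    out = where(LabelRecorder())
    if isinstance(out, tuple):
        return tuple(node.label for node in out)
    return out.label


print(where_to_labels(lambda s: s.mechanics.effector))
# mechanics.effector
print(where_to_labels(lambda s: s.net["hidden"][0]))
# net['hidden'][0]
print(where_to_labels(lambda s: (s.mechanics.effector.pos, s.net.hidden)))
# ('mechanics.effector.pos', 'net.hidden')
```

No state tree is needed here: the where function is simply called on the recorder, which is why dict keys and sequence indices come for free.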
