mlprt / feedbax

Optimal feedback control + interventions, in JAX.
https://docs.lprt.ca/feedbax
Apache License 2.0

The suitability of `WhereDict`: lambdas as keys #14

Open mlprt opened 4 months ago

mlprt commented 4 months ago

AbstractTask provides data with which to initialize subsets of the model state, at the start of each task trial.

This could be accomplished by providing a PyTree with the same structure as the full model state, with None at all leaves except those to be initialized, and then using eqx.combine to replace the corresponding model substates with the initial values provided. However, this would require each type of AbstractTask to be associated with a particular type of state PyTree. In principle, a task type should instead be compatible with any model whose state contains at least the substates that 1) are to be initialized and 2) are targets of, or otherwise part of, the loss computation. For example, when defining a task in terms of initial and target states for the biomechanical effector, it shouldn't matter how complex the PyTree of states for the neural network is.

Therefore, in AbstractTask the initial values for substates are specified as pairings of 1) a lambda that selects the substate to be initialized from the full model state, with 2) a PyTree of data with the same structure as that substate. TaskTrainer then performs a series of equinox.tree_at surgeries based on this mapping:

https://github.com/mlprt/feedbax/blob/2ce8b1c25b5790dc0a41a5812eeb6d830b3a151f/feedbax/train.py#L516-L521

As far as TaskTrainer is concerned, it would be sufficient to provide these pairings as tuple[Callable, PyTree]. In general, however, it seems to make sense for the pairings to form a mapping. First, there should be at most one initialization for each substate. Second, if the user (or a function in feedbax.plot) wants to access the initialization data from the trial specification, it is more convenient to write (say) trial_spec.init['mechanics.effector'] or even trial_spec.init[lambda state: state.mechanics.effector] than to have to figure out which tuple[Callable, PyTree] contains the Callable that refers to the part of the state they're interested in.

Unfortunately, lambdas cannot simply be used as keys in a mapping. If I define a dict with the key lambda state: state.something, and later try to get the value associated with a newly defined key lambda s: s.something, I'll encounter a KeyError. Functions are hashed and compared by identity (essentially by memory address), not according to the computation they represent.
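A minimal pure-Python demonstration of the problem:

```python
where = lambda state: state.something
inits = {where: 1.0}

# Looking up with the *same* function object works...
assert inits[where] == 1.0

# ...but an equivalent, newly defined lambda is a different key, because
# functions hash and compare by identity, not by what they compute.
try:
    inits[lambda s: s.something]
    found = True
except KeyError:
    found = False

print(found)  # False
```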

Thus we use WhereDict, an OrderedDict that permits limited use of lambdas as keys. In particular, it uses dis to parse the LOAD_ATTR operations in the lambda's bytecode, and constructs an equivalent string representation. For example, lambda x: x.foo.bar is parsed as "foo.bar". This only works when the lambda takes a single argument and returns a single (nested) attribute access on that argument.

Both a lambda and its string representation can be used as a WhereDict key:

assert my_where_dict['foo.bar'] is my_where_dict[lambda x: x.foo.bar]
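A minimal sketch of the idea, not feedbax's actual WhereDict. `where_to_str` is a hypothetical helper, and the set of opcodes it tolerates around LOAD_ATTR (RESUME, LOAD_FAST variants, RETURN_VALUE) is an assumption based on how recent CPython versions compile `lambda x: x.foo.bar`; other versions may emit different opcodes:

```python
import dis
from collections import OrderedDict
from collections.abc import Callable, Mapping


def where_to_str(where: Callable) -> str:
    """Convert e.g. `lambda x: x.foo.bar` into "foo.bar" by reading its bytecode."""
    if where.__code__.co_argcount != 1:
        raise ValueError("Only single-argument lambdas are supported")
    attrs = []
    for instr in dis.Bytecode(where):
        if instr.opname == "LOAD_ATTR":
            attrs.append(instr.argval)  # the attribute name
        elif not (
            instr.opname.startswith("LOAD_FAST")
            or instr.opname in ("RESUME", "RETURN_VALUE")
        ):
            # Calls, subscripts, globals, etc. are not supported.
            raise ValueError(f"Unsupported operation: {instr.opname}")
    if not attrs:
        raise ValueError("No attribute access found")
    return ".".join(attrs)


class WhereDict(OrderedDict):
    """Hypothetical sketch: an OrderedDict whose keys may be given either as
    attribute-path strings or as equivalent `where` lambdas."""

    def __init__(self, items=()):
        super().__init__()
        if isinstance(items, Mapping):
            items = items.items()
        for key, value in items:
            self[key] = value  # goes through the normalizing __setitem__

    @staticmethod
    def _normalize(key):
        return key if isinstance(key, str) else where_to_str(key)

    def __setitem__(self, key, value):
        super().__setitem__(self._normalize(key), value)

    def __getitem__(self, key):
        return super().__getitem__(self._normalize(key))

    def __contains__(self, key):
        return super().__contains__(self._normalize(key))


my_where_dict = WhereDict([(lambda x: x.foo.bar, 42)])
assert my_where_dict["foo.bar"] is my_where_dict[lambda x: x.foo.bar]
```

This sketch only normalizes `[]`, `in`, and assignment, and stores just the string. A version used by TaskTrainer would also need to keep the original callables, since eqx.tree_at needs them for the surgeries.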

The downside is the overhead of dis.Bytecode: WhereDict is about 100x slower to construct, and 500-1000x slower to access, than OrderedDict. In practice this is not a big deal for our use case, as we only need to do a single construction and a single access per training batch, for an overhead of about 125 µs, whereas a batch normally takes at least 20,000 µs. Also, the user is unlikely to initialize more than a few substates separately, i.e. a WhereDict will rarely have more than a few entries.
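By the figures above, 125 µs out of at least 20,000 µs is under 1% of a batch. A rough way to see where the per-key cost comes from is to time a plain string-key lookup against building and walking a lambda's bytecode with dis.Bytecode. This is only an illustrative micro-benchmark; the absolute numbers and the ratio depend on the machine and Python version, so no expected output is given:

```python
import dis
import timeit
from collections import OrderedDict

where = lambda state: state.mechanics.effector
plain = OrderedDict({"mechanics.effector": 0})

n = 10_000
# Average time of a plain string-key lookup.
t_lookup = timeit.timeit(lambda: plain["mechanics.effector"], number=n) / n
# Average time to build and iterate over the bytecode of a where lambda,
# i.e. the extra work needed before a lambda-keyed lookup can happen.
t_dis = timeit.timeit(lambda: list(dis.Bytecode(where)), number=n) / n

print(f"string-key lookup:  {t_lookup * 1e6:.3f} us")
print(f"dis.Bytecode parse: {t_dis * 1e6:.3f} us ({t_dis / t_lookup:.0f}x)")
```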

Is there a better or faster way to do this?

I've considered specifying the state initialization as a prefix of the model state, rather than as a lambda paired with a substate. However, this does not solve the access problem: we'd still need to map to the prefixes with string keys like "mechanics.effector" so that the user could refer to them easily. But then 1) those keys would no longer have a lawful relationship to the surgeries performed to assign the initial states, and 2) the user would still have to access the appropriate leaves from the prefix tree.

mlprt commented 2 months ago

Given a where function and a tree, we can obtain the paths of the nodes in the tree that are returned by the function using something like this:

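import equinox as eqx
import jax.tree_util as jtu

class _NodeWrapper:
    """Wraps a selected node so that it is treated as a single leaf."""
    def __init__(self, value):
        self.value = value

def where_func_to_paths(where, tree):
    # Wrap each node selected by `where`, so that whole subtrees become leaves.
    tree = eqx.tree_at(where, tree, replace_fn=lambda x: _NodeWrapper(x))
    # Replace every leaf (including the wrappers) with its object id.
    id_tree = jtu.tree_map(id, tree, is_leaf=lambda x: isinstance(x, _NodeWrapper))
    # Applying `where` to the id tree yields the ids of the selected nodes.
    node_ids = where(id_tree)
    # `where` may return a single node or several, so collect the ids in a set
    # (a bare `x in node_ids` fails when `node_ids` is a single int).
    node_id_set = set(jtu.tree_leaves(node_ids))

    # Keep only the selected ids (`None` leaves are dropped when flattening),
    # then map each remaining id to its key path.
    paths_by_id = {leaf_id: path for path, leaf_id in jtu.tree_leaves_with_path(
        jtu.tree_map(lambda x: x if x in node_id_set else None, id_tree)
    )}

    # Return the paths with the same structure as the output of `where`.
    paths = jtu.tree_map(lambda node_id: paths_by_id[node_id], node_ids)

    return paths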

When profiled on a couple of typical where functions applied to a SimpleFeedbackState tree, this solution seems to be 2-3x slower than the dis solution. However, as mentioned in the preceding comment, we call this function infrequently, so that's probably not an issue.

Note the use of _NodeWrapper, which lets us identify the paths of nodes (possibly whole subtrees), and not just leaves. However, the tree_at call is the slowest part of the function. Perhaps there is some way to use is_leaf in the first tree_map to avoid the call to tree_at.
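One possible way to avoid tree_at, sketched here as a hypothetical alternative (`where_func_to_paths_by_identity` is not part of feedbax): call `where` on the real tree, then flatten with an `is_leaf` that stops at the selected nodes by object identity. This assumes each selected node occurs exactly once in the tree, which may not hold if, for example, the same array object is shared between fields. Whether it is actually faster would need profiling:

```python
import jax.numpy as jnp
import jax.tree_util as jtu


def where_func_to_paths_by_identity(where, tree):
    """Hypothetical sketch: key paths of the nodes selected by `where`."""
    selected = where(tree)
    # As with eqx.tree_at, a tuple is interpreted as several selected nodes.
    nodes = selected if isinstance(selected, tuple) else (selected,)

    # Stop flattening at any selected node, so whole subtrees become leaves.
    def is_selected(x):
        return any(x is node for node in nodes)

    leaves_with_paths, _ = jtu.tree_flatten_with_path(tree, is_leaf=is_selected)

    paths = []
    for node in nodes:
        matches = [path for path, leaf in leaves_with_paths if leaf is node]
        if len(matches) != 1:
            raise ValueError("Each selected node must occur exactly once in the tree")
        paths.append(matches[0])
    return tuple(paths) if isinstance(selected, tuple) else paths[0]


tree = {
    "mechanics": {"effector": {"pos": jnp.zeros(2), "vel": jnp.zeros(2)}},
    "net": {"hidden": jnp.zeros(4)},
}
path = where_func_to_paths_by_identity(lambda t: t["mechanics"]["effector"], tree)
print(jtu.keystr(path))  # ['mechanics']['effector']
```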

Example profile

![image](https://github.com/mlprt/feedbax/assets/10875986/5e7fff4e-c576-421a-863f-0cff7af2837e)

Limitations:

Advantages:

mlprt commented 2 months ago

A third option is to use the where function to construct a string representation directly, like so:

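from collections.abc import Callable
from typing import Any

import jax
from jaxtyping import PyTree

class WhereStrConstructor:
    """Records the attribute and item accesses made on it as a string label."""

    def __init__(self, label: str = ""):
        self.label = label

    def __getitem__(self, key: Any):
        if isinstance(key, str):
            key = f"'{key}'"
        elif isinstance(key, type):
            key = key.__name__
        # Add other conditional representations, as needed.
        return WhereStrConstructor("".join([self.label, f"[{key}]"]))

    def __getattr__(self, name: str):
        # Only called for attributes not found normally (i.e. anything but `label`).
        sep = "." if self.label else ""
        return WhereStrConstructor(sep.join([self.label, name]))

def where_func_to_labels(where: Callable) -> PyTree[str]:
    # Call `where` on a recorder instead of a real state, then collect the
    # label of each node it returns.
    return jax.tree_util.tree_map(
        lambda x: x.label,
        where(WhereStrConstructor()),
        is_leaf=lambda x: isinstance(x, WhereStrConstructor),
    )

# For example,
# where_func_to_labels(lambda state: (state.mechanics.effector, state.net[0]))
# returns ('mechanics.effector', 'net[0]').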

Advantages:

Disadvantages: