Open mlprt opened 4 months ago
Given a `where` function and a `tree`, we can obtain the paths of the nodes in the tree that are returned by the function using something like this:
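```python
import equinox as eqx
import jax.tree_util as jtu


class _NodeWrapper:
    """Wraps a selected node so that `tree_map` treats the whole node as a leaf."""
    def __init__(self, value):
        self.value = value


def where_func_to_paths(where, tree):
    # Wrap the nodes selected by `where`, then replace every leaf with its `id`.
    tree = eqx.tree_at(where, tree, replace_fn=lambda x: _NodeWrapper(x))
    id_tree = jtu.tree_map(id, tree, is_leaf=lambda x: isinstance(x, _NodeWrapper))
    # The ids of the selected nodes, in the PyTree structure returned by `where`.
    node_ids = where(id_tree)
    # Use a set: `node_ids` may be a single int, for which `x in node_ids` fails.
    node_id_set = set(jtu.tree_leaves(node_ids))
    # Keep only the selected ids (`None` leaves are dropped) and record their paths.
    paths_by_id = {leaf_id: path for path, leaf_id in jtu.tree_leaves_with_path(
        jtu.tree_map(lambda x: x if x in node_id_set else None, id_tree)
    )}
    return jtu.tree_map(lambda node_id: paths_by_id[node_id], node_ids)
```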
Profiling a couple of typical `where` functions on a `SimpleFeedbackState` tree suggests that this solution is 2-3x slower than the `dis` solution. However, as mentioned in the preceding comment, we call this function infrequently, so that's probably not an issue.
Note the use of `_NodeWrapper`, so that we can identify paths of nodes and not just leaves. However, the `tree_at` call is the slowest part of the function. Perhaps there is some way to use `is_leaf` in the first `tree_map` to avoid the call to `tree_at`.
![image](https://github.com/mlprt/feedbax/assets/10875986/5e7fff4e-c576-421a-863f-0cff7af2837e)
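On the `is_leaf` idea: one possible way to avoid `tree_at` is to match nodes by object identity instead of wrapping them. The sketch below is hypothetical (the name `where_func_to_paths_by_identity` is made up, and it has not been profiled against the version above). It assumes every node selected by `where` is a distinct object within `tree`; shared objects such as `None`, a cached small int, or one array reused in two places would make the match ambiguous.

```python
import jax.tree_util as jtu


def where_func_to_paths_by_identity(where, tree):
    """Hypothetical variant of `where_func_to_paths` that avoids `eqx.tree_at`.

    Assumes every node selected by `where` is a distinct object in `tree`.
    """
    # Record the id of every node (internal or leaf) that flattening visits;
    # returning False from `is_leaf` lets the traversal continue as usual.
    tree_node_ids = set()

    def _record(x):
        tree_node_ids.add(id(x))
        return False

    jtu.tree_flatten(tree, is_leaf=_record)

    def is_tree_node(x):
        return id(x) in tree_node_ids

    # Containers built by `where` itself (e.g. the tuple in
    # `lambda s: (s.a, s.b)`) are new objects, so they are traversed; the
    # outermost object that belongs to `tree` is treated as one selected node.
    selected = where(tree)
    target_ids = {id(x) for x in jtu.tree_leaves(selected, is_leaf=is_tree_node)}

    # Flatten `tree` again, stopping at the selected nodes, to get their paths.
    leaves_with_paths, _ = jtu.tree_flatten_with_path(
        tree, is_leaf=lambda x: id(x) in target_ids
    )
    paths_by_id = {id(x): path for path, x in leaves_with_paths if id(x) in target_ids}

    # Same structure as `where`'s output, with a key path in place of each node.
    return jtu.tree_map(lambda x: paths_by_id[id(x)], selected, is_leaf=is_tree_node)
```

Like the `_NodeWrapper` version, this can't return paths for both a node and a node nested inside it, since the second flatten stops at the outer node.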
Limitations:

- Unlike the `dis` solution, this requires access to a PyTree with the appropriate nodes.
- The `where` function can't specify both a node and another node contained within that node, since the outer node will get wrapped in `_NodeWrapper`, which will mask the inner one. This should be solvable.

Advantages:

- `where` can return a PyTree of nodes, and this function will give the respective PyTree of paths. By contrast, the `dis` solution only works (so far) for a `where` that selects a single node; extending it to arbitrary PyTrees would require more complex parsing.
- The `dis` solution only works for attribute access, whereas this one works for whatever access `where` performs on the tree (e.g. indexing).

A third option is to use the `where` function to construct a string representation directly, like so:
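```python
from collections.abc import Callable
from typing import Any

import jax.tree_util as jtu
from jaxtyping import PyTree


class WhereStrConstructor:
    """Records the attribute and item accesses made on it as a string label."""

    def __init__(self, label: str = ""):
        self.label = label

    def __getitem__(self, key: Any):
        if isinstance(key, str):
            key = f"'{key}'"
        elif isinstance(key, type):
            key = key.__name__
        # Add other conditional representations, as needed.
        return WhereStrConstructor("".join([self.label, f"[{key}]"]))

    def __getattr__(self, name: str):
        # Only called for attributes not found normally, so `self.label` is safe.
        sep = "." if self.label else ""
        return WhereStrConstructor(sep.join([self.label, name]))


def where_func_to_labels(where: Callable) -> PyTree[str]:
    # `jax.tree_map` is deprecated; `jax.tree_util.tree_map` is the stable name.
    return jtu.tree_map(lambda x: x.label, where(WhereStrConstructor()))
```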
Advantages:

Disadvantages:

- … `where` function.
- … `where` is relevant to. This shouldn't be an issue, since any problems with the actual structure of the tree would lead to errors being raised elsewhere anyway.
`AbstractTask` provides data with which to initialize subsets of the model state at the start of each task trial.

This could be accomplished by providing a PyTree with the same structure as the full model state, with `None` at all leaves except those to be initialized, and then using `eqx.combine` to replace the model substates with the initial values provided. However, this would require that each type of `AbstractTask` be associated with a particular type of state PyTree, whereas in principle a task type should be compatible with any model whose associated state contains at least the substates 1) to be initialized and 2) that are targets/part of the loss computation. For example, it shouldn't matter how complex the PyTree of states for the neural network is when defining a task in terms of initial and target states for the biomechanical effector.

Therefore, in `AbstractTask` the initial values for substates are specified (example) as a pairing of a lambda that selects the substate to be initialized from the full model state, with a PyTree of data with the same structure as that substate. Then `TaskTrainer` performs a series of `equinox.tree_at` surgeries based on this mapping: https://github.com/mlprt/feedbax/blob/2ce8b1c25b5790dc0a41a5812eeb6d830b3a151f/feedbax/train.py#L516-L521
As far as `TaskTrainer` is concerned, it would be sufficient to provide these pairings as `tuple[Callable, PyTree]`. However, in general it seems to make sense for the pairing to be a mapping -- first, because there should be at most a single initialization provided for each substate. But also, if the user -- or a function in `feedbax.plot` -- wants to access the initialization data from the trial specification, it is more convenient to write (say) `trial_spec.init['mechanics.effector']` or even `trial_spec.init[lambda state: state.mechanics.effector]` than to have to figure out which `tuple[Callable, PyTree]` contains the `Callable` that refers to the part of the state they're interested in.

Unfortunately, lambdas cannot simply be used as keys in a mapping. If I define a dict with a key `lambda state: state.something`, and later try to get the value associated with a newly defined key `lambda s: s.something`, I'll encounter a `KeyError`, because lambdas are not hashed according to the function they represent, but by their memory address.

Thus we use `WhereDict`, which is an `OrderedDict` that enables limited use of lambdas as keys. In particular, it uses `dis` to parse the `LOAD_ATTR` operations in the lambda's bytecode, and constructs an equivalent string representation. For example, `lambda x: x.foo.bar` is parsed as `"foo.bar"`. This only works when the lambda takes a single argument and returns a single (nested) attribute access on the argument. Both a lambda and its string representation can be used as a `WhereDict` key.

The downside is the overhead of `dis.Bytecode`, such that `WhereDict` is about 100x slower to construct, and 500-1000x slower to access, than `OrderedDict`. In practice this is not a big deal for our use case, as we only need to do a single construction and a single access on each training batch, leading to an overhead of about 125 µs, where a batch normally takes at least 20,000 µs. Also, it's unlikely the user will initialize more than a few substates separately, i.e. there will be no more than a few entries per `WhereDict`.

Is there a better or faster way to do this?
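For reference, a minimal plain-Python sketch of the two points above. The first part reproduces the `KeyError`; `where_to_str` and `WhereDictSketch` are hypothetical stand-ins for `WhereDict`, not its actual implementation. This sketch uses `dis.get_instructions` rather than `dis.Bytecode`, and relies on attribute access compiling to `LOAD_ATTR` instructions whose `argval` is the attribute name.

```python
import dis
from collections import OrderedDict
from collections.abc import Callable, Hashable

# Functions hash by identity, so an equivalent lambda is a different key.
inits = {lambda state: state.something: "init data"}
try:
    inits[lambda s: s.something]
    found = True
except KeyError:
    found = False  # a newly defined lambda is a new key


def where_to_str(where: Callable) -> str:
    """Hypothetical helper: `lambda x: x.foo.bar` -> "foo.bar".

    Only meaningful for a single-argument lambda returning one (nested)
    attribute access; item access, calls, etc. are silently ignored.
    """
    return ".".join(
        instr.argval
        for instr in dis.get_instructions(where)
        if instr.opname == "LOAD_ATTR"
    )


class WhereDictSketch(OrderedDict):
    """Hypothetical minimal `WhereDict`: callable keys are stored as strings."""

    @staticmethod
    def _key(key: Hashable) -> Hashable:
        # Strings pass through; callables are converted via their bytecode.
        return where_to_str(key) if callable(key) else key

    def __setitem__(self, key, value):
        super().__setitem__(self._key(key), value)

    def __getitem__(self, key):
        return super().__getitem__(self._key(key))

    def __contains__(self, key):
        return super().__contains__(self._key(key))


wd = WhereDictSketch()
wd[lambda state: state.mechanics.effector] = "effector init"
```

Every lookup with a lambda re-runs `dis.get_instructions`, so this sketch carries the same kind of per-access overhead described above; string keys skip the parsing.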
I've considered specifying the state initialization as a prefix of the model state, rather than as a lambda combined with a substate. However, this does not solve the access problem: we'd still need to map to the prefixes with string keys like `"mechanics.effector"` so that the user could refer to them easily. But 1) those keys would no longer have a lawful relationship to the surgeries performed to assign the initial states, and 2) the user would still have to access the appropriate leaves from the prefix tree.
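To make the prefix alternative concrete, here is a small pure-JAX sketch, assuming a dict-based state and a hypothetical `init_from_prefix` helper (not feedbax code): a `None` in the prefix keeps the matching subtree of the state, and any other leaf replaces it. The last line illustrates point 2): the initial values have to be dug out of the prefix tree by hand.

```python
import jax.numpy as jnp
import jax.tree_util as jtu


def init_from_prefix(prefix, state):
    """Hypothetical helper: replace parts of `state` with non-None leaves of `prefix`.

    `prefix` must match `state`'s structure down to each initialized substate;
    a `None` in `prefix` keeps the corresponding subtree of `state` unchanged.
    """
    return jtu.tree_map(
        lambda init, sub: sub if init is None else init,
        prefix,
        state,
        is_leaf=lambda x: x is None,
    )


state = {
    "mechanics": {
        "effector": {"pos": jnp.zeros(2), "vel": jnp.zeros(2)},
        "muscles": jnp.zeros(6),
    },
    "net": {"hidden": jnp.zeros(4)},
}

# Initialize only the effector; every other branch is None in the prefix.
prefix = {
    "mechanics": {
        "effector": {"pos": jnp.ones(2), "vel": jnp.zeros(2)},
        "muscles": None,
    },
    "net": None,
}

new_state = init_from_prefix(prefix, state)

# Reading the effector's initial values back means walking the prefix tree.
effector_init = prefix["mechanics"]["effector"]
```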