**nrontsis** opened this issue 4 years ago (status: Open)
Did you find any solution to this problem? And do you know why the pytrees use so much memory?
I believe the issue is not with pytrees per se, but rather with defining hundreds of array arguments, and compiling functions which take hundreds of array arguments. Generally speaking, compilation costs should not scale with the size of the individual arrays being operated on, but we do expect them to scale with the number of array objects being passed to the function. In the best case, you might achieve linear scaling – but I suspect in reality you'll see closer to quadratic scaling with the number of array inputs. This is not unexpected, because it will generally lead to much larger programs which require much more logic to optimize.
Does that answer your question?
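As a rough illustration of the scaling claim above, here is a small (hypothetical, not from the thread) timing harness: it jits a function whose input dict has a configurable number of array leaves and measures the first call, which includes tracing and compilation. The function name `compile_time` and the toy loss are assumptions for demonstration only.

```python
import time
import jax
import jax.numpy as jnp

def compile_time(num_leaves):
    """Time the first (tracing + compiling) call of a jitted function
    whose input pytree has `num_leaves` separate array leaves."""
    params = {f"p_{i}": jnp.ones(3) for i in range(num_leaves)}

    @jax.jit
    def f(tree):
        # One reduction per leaf, so the traced program grows with the leaf count.
        return sum(jnp.sum(x ** 2) for x in tree.values())

    start = time.perf_counter()
    f(params).block_until_ready()
    return time.perf_counter() - start

for n in (10, 100, 500):
    print(n, compile_time(n))
```

On a typical machine the times grow noticeably faster than linearly with `n`, which matches the "closer to quadratic" observation above.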
I am working with discrete dynamical systems that depend on some parameters that are to be trained. The key point is that I want to express the state of the dynamical system not as an array, but as a `Dict` (or more generally as a `PyTree`) so that I can write e.g. `state["position"]` instead of the much less readable `state[idx]`.

In the following minimal example, I demonstrate the issues I am running into when trying to do this using generic tools from `jax.tree_util` that act on a `state` `PyTree`. The compilation takes a lot of time and the resulting (pre-optimised) HLO files are up to 26 MBytes!

**Implementation with PyTrees**
```python
import jax.numpy as np
from jax.tree_util import tree_multimap, tree_map, tree_reduce, tree_flatten
from jax import jit, partial, grad
from jax.lax import scan


@partial(jit, static_argnums=(0,))
def sum_of_squares_loss(dynamics: callable, states, parameters):
    initial_states = tree_map(lambda s: s[0], states)
    horizon_length = tree_flatten(states)[0][0].shape[0]
    predictions = propagate_dynamics(dynamics, initial_states, horizon_length, parameters)
    errors = tree_multimap(lambda s, p: np.sum((s - p)**2, keepdims=True), states, predictions)
    return tree_reduce(sum, errors)[0]


sum_of_squares_loss_gradient = jit(grad(sum_of_squares_loss, argnums=2), static_argnums=(0,))


def propagate_dynamics(dynamics: callable, initial_states, horizon_length, parameters):
    return scan(
        f=lambda state, _: (dynamics(state, parameters), state),
        init=initial_states,
        xs=None,
        length=horizon_length,
    )[1]


### Example call
PROPAGATION_HORIZON_LENGTH = 200
STATE_DIMENSION = 100
STATE_NAMES = ["state_" + str(i) for i in range(STATE_DIMENSION)]


def example_dynamics(states, parameters):
    return tree_multimap(lambda s, p: 0.999*s + 1e-3*p, states, parameters)


states = {name: np.ones(PROPAGATION_HORIZON_LENGTH) for name in STATE_NAMES}
parameters = {name: 1.0 for name in STATE_NAMES}

# Compile functions
sum_of_squares_loss(example_dynamics, states, parameters)
sum_of_squares_loss_gradient(example_dynamics, states, parameters)
```

Resulting XLA dump for the above
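To reproduce this kind of measurement without writing HLO dump files to disk, one option in more recent JAX versions is the ahead-of-time lowering API, `jax.jit(...).lower(...).as_text()`, which returns the pre-optimisation program as text. This is a sketch on a smaller toy function (the function and sizes here are illustrative, not the example from this issue):

```python
import jax
import jax.numpy as jnp

def loss(params):
    # Toy loss with one reduction per pytree leaf.
    return sum(jnp.sum(p ** 2) for p in params.values())

big = {f"p_{i}": jnp.ones(4) for i in range(64)}
small = {f"p_{i}": jnp.ones(4) for i in range(2)}

# `lower` traces without compiling; `as_text` returns the unoptimised program text.
big_text = jax.jit(loss).lower(big).as_text()
small_text = jax.jit(loss).lower(small).as_text()
print(len(big_text), len(small_text))
```

The text size of the lowered program grows with the number of leaves, mirroring the HLO file sizes reported in this issue.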
For comparison, an equivalent version of the above example that only acts on arrays results in (pre-optimised) HLO files of up to 50 KBytes.
**Fully vectorised code (no pytrees/dicts involved)**
```python
import jax.numpy as np
from jax.tree_util import tree_multimap, tree_map, tree_reduce, tree_flatten
from jax import jit, partial, grad
from jax.lax import scan


@partial(jit, static_argnums=(0,))
def sum_of_squares_loss(dynamics: callable, states, parameters):
    initial_states = states[0]
    horizon_length = states.shape[0]
    predictions = propagate_dynamics(dynamics, initial_states, horizon_length, parameters)
    return np.sum(predictions - states)


sum_of_squares_loss_gradient = jit(grad(sum_of_squares_loss, argnums=2), static_argnums=(0,))


def propagate_dynamics(dynamics: callable, initial_states, horizon_length, parameters):
    return scan(
        f=lambda state, _: (dynamics(state, parameters), state),
        init=initial_states,
        xs=None,
        length=horizon_length,
    )[1]


### Example call
PROPAGATION_HORIZON_LENGTH = 200
STATE_DIMENSION = 100
STATE_NAMES = ["state_" + str(i) for i in range(STATE_DIMENSION)]


def example_dynamics(states, parameters):
    return 0.999*states + 1e-3*parameters


states = np.ones((PROPAGATION_HORIZON_LENGTH, STATE_DIMENSION))
parameters = np.ones((STATE_DIMENSION,))

# Compile function
sum_of_squares_loss(example_dynamics, states, parameters)
sum_of_squares_loss_gradient(example_dynamics, states, parameters)
```

Resulting XLA dump for the above
Finally, as a comparison, I considered flattening and unflattening the states into a dict inside every call of the dynamics function. I thought that this might be inefficient due to this comment. However, it performs better than the PyTree version, resulting in (pre-optimised) HLO files of up to 1.1 MBytes.
**Explicitly convert from arrays to dicts and back at the innermost function**
```python
import jax.numpy as np
from jax.tree_util import tree_multimap, tree_map, tree_reduce, tree_flatten
from jax import jit, partial, grad
from jax.lax import scan


@partial(jit, static_argnums=(0,))
def sum_of_squares_loss(dynamics: callable, states, parameters):
    initial_states = states[0]
    horizon_length = states.shape[0]
    predictions = propagate_dynamics(dynamics, initial_states, horizon_length, parameters)
    return np.sum(predictions - states)


sum_of_squares_loss_gradient = jit(grad(sum_of_squares_loss, argnums=2), static_argnums=(0,))


def propagate_dynamics(dynamics: callable, initial_states, horizon_length, parameters):
    return scan(
        f=lambda state, _: (dynamics(state, parameters), state),
        init=initial_states,
        xs=None,
        length=horizon_length,
    )[1]


### Example call
PROPAGATION_HORIZON_LENGTH = 200
STATE_DIMENSION = 100
STATE_NAMES = ["state_" + str(i) for i in range(STATE_DIMENSION)]


def example_dynamics(states, parameters):
    states_dict = {n: s for n, s in zip(STATE_NAMES, states)}
    parameters_dict = {n: p for n, p in zip(STATE_NAMES, parameters)}
    new_states_dict = {n: 0.999*states_dict[n] + 1e-3*parameters_dict[n] for n in STATE_NAMES}
    return np.array(list(new_states_dict.values()))


states = np.ones((PROPAGATION_HORIZON_LENGTH, STATE_DIMENSION))
parameters = np.ones((STATE_DIMENSION,))

# Compile functions
sum_of_squares_loss(example_dynamics, states, parameters)
sum_of_squares_loss_gradient(example_dynamics, states, parameters)
```

Resulting XLA dump for the above
So my question is: what is the best way to write/wrap a function that acts on a `PyTree` without incurring prohibitively large compile times?
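One possible pattern (a sketch, not an answer given in this thread) is to keep the readable dict view only inside the step function, while presenting XLA with a single flat array at the boundary. Modern JAX provides `jax.flatten_util.ravel_pytree` for exactly this conversion; the names `dynamics`, `flat_dynamics`, and `rollout` below are hypothetical:

```python
import jax
import jax.numpy as jnp
from jax.flatten_util import ravel_pytree

# Readable dynamics, written against a dict state (illustrative).
def dynamics(state, params):
    return {k: 0.999 * state[k] + 1e-3 * params[k] for k in state}

state = {"position": jnp.ones(3), "velocity": jnp.zeros(3)}
params = {"position": jnp.ones(3), "velocity": jnp.ones(3)}

# Flatten once at the boundary: the scan carry is then a single array.
flat_state, unravel = ravel_pytree(state)

def flat_dynamics(flat, params):
    # Unravel only inside the step, for readable field access.
    new_state = dynamics(unravel(flat), params)
    return ravel_pytree(new_state)[0]

@jax.jit
def rollout(flat, params):
    def step(carry, _):
        nxt = flat_dynamics(carry, params)
        return nxt, nxt
    _, trajectory = jax.lax.scan(step, flat, None, length=100)
    return trajectory

trajectory = rollout(flat_state, params)
print(trajectory.shape)  # (100, 6)
```

Individual states along the trajectory can be turned back into dicts with `unravel(trajectory[i])`, so the readable names are preserved everywhere in user code while the compiled program only ever sees one array.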