jax-ml / jax

Composable transformations of Python+NumPy programs: differentiate, vectorize, JIT to GPU/TPU, and more
http://jax.readthedocs.io/
Apache License 2.0

Slow compilation for functions acting on PyTrees #4667

Open nrontsis opened 3 years ago

nrontsis commented 3 years ago

I am working with discrete dynamical systems that depend on parameters to be trained. The key point is that I want to express the state of the dynamical system not as an array but as a dict (or, more generally, a PyTree), so that I can write e.g. state["position"] instead of the much less readable state[idx].

In the following minimal example, I demonstrate the issues I am running into when trying to do this using generic tools from jax.tree_util that act on a state PyTree. The compilation takes a lot of time and the resulting (pre-optimised) HLO files are up to 26 MBytes!

Implementation with PyTrees:

```python
import jax.numpy as np
from jax.tree_util import tree_multimap, tree_map, tree_reduce, tree_flatten
from jax import jit, partial, grad
from jax.lax import scan


@partial(jit, static_argnums=(0,))
def sum_of_squares_loss(dynamics: callable, states, parameters):
    initial_states = tree_map(lambda s: s[0], states)
    horizon_length = tree_flatten(states)[0][0].shape[0]
    predictions = propagate_dynamics(dynamics, initial_states, horizon_length, parameters)
    errors = tree_multimap(lambda s, p: np.sum((s - p)**2, keepdims=True), states, predictions)
    return tree_reduce(sum, errors)[0]


sum_of_squares_loss_gradient = jit(grad(sum_of_squares_loss, argnums=2), static_argnums=(0,))


def propagate_dynamics(dynamics: callable, initial_states, horizon_length, parameters):
    return scan(
        f=lambda state, _: (dynamics(state, parameters), state),
        init=initial_states,
        xs=None,
        length=horizon_length
    )[1]


### Example call
PROPAGATION_HORIZON_LENGTH = 200
STATE_DIMENSION = 100
STATE_NAMES = ["state_" + str(i) for i in range(STATE_DIMENSION)]


def example_dynamics(states, parameters):
    return tree_multimap(lambda s, p: 0.999*s + 1e-3*p, states, parameters)


states = {name: np.ones(PROPAGATION_HORIZON_LENGTH) for name in STATE_NAMES}
parameters = {name: 1.0 for name in STATE_NAMES}

# Compile functions
sum_of_squares_loss(example_dynamics, states, parameters)
sum_of_squares_loss_gradient(example_dynamics, states, parameters)
```

Resulting XLA dump for the above

For comparison, an equivalent version of the above example that only acts on arrays results in (pre-optimised) HLO files of up to 50 KBytes.

Fully vectorised code (no pytrees/dicts involved):

```python
import jax.numpy as np
from jax.tree_util import tree_multimap, tree_map, tree_reduce, tree_flatten
from jax import jit, partial, grad
from jax.lax import scan


@partial(jit, static_argnums=(0,))
def sum_of_squares_loss(dynamics: callable, states, parameters):
    initial_states = states[0]
    horizon_length = states.shape[0]
    predictions = propagate_dynamics(dynamics, initial_states, horizon_length, parameters)
    return np.sum(predictions - states)


sum_of_squares_loss_gradient = jit(grad(sum_of_squares_loss, argnums=2), static_argnums=(0,))


def propagate_dynamics(dynamics: callable, initial_states, horizon_length, parameters):
    return scan(
        f=lambda state, _: (dynamics(state, parameters), state),
        init=initial_states,
        xs=None,
        length=horizon_length
    )[1]


### Example call
PROPAGATION_HORIZON_LENGTH = 200
STATE_DIMENSION = 100
STATE_NAMES = ["state_" + str(i) for i in range(STATE_DIMENSION)]


def example_dynamics(states, parameters):
    return 0.999*states + 1e-3*parameters


states = np.ones((PROPAGATION_HORIZON_LENGTH, STATE_DIMENSION))
parameters = np.ones((STATE_DIMENSION,))

# Compile functions
sum_of_squares_loss(example_dynamics, states, parameters)
sum_of_squares_loss_gradient(example_dynamics, states, parameters)
```

Resulting XLA dump for the above

Finally, as a further comparison, I considered flattening and unflattening the states into a dict inside every call of the dynamics function. I thought this might be inefficient due to this comment. However, it performs better than the PyTree version, resulting in (pre-optimised) HLO files of up to 1.1 MBytes.

Explicitly convert from arrays to dicts and back at the innermost function:

```python
import jax.numpy as np
from jax.tree_util import tree_multimap, tree_map, tree_reduce, tree_flatten
from jax import jit, partial, grad
from jax.lax import scan


@partial(jit, static_argnums=(0,))
def sum_of_squares_loss(dynamics: callable, states, parameters):
    initial_states = states[0]
    horizon_length = states.shape[0]
    predictions = propagate_dynamics(dynamics, initial_states, horizon_length, parameters)
    return np.sum(predictions - states)


sum_of_squares_loss_gradient = jit(grad(sum_of_squares_loss, argnums=2), static_argnums=(0,))


def propagate_dynamics(dynamics: callable, initial_states, horizon_length, parameters):
    return scan(
        f=lambda state, _: (dynamics(state, parameters), state),
        init=initial_states,
        xs=None,
        length=horizon_length
    )[1]


### Example call
PROPAGATION_HORIZON_LENGTH = 200
STATE_DIMENSION = 100
STATE_NAMES = ["state_" + str(i) for i in range(STATE_DIMENSION)]


def example_dynamics(states, parameters):
    states_dict = {n: s for n, s in zip(STATE_NAMES, states)}
    parameters_dict = {n: p for n, p in zip(STATE_NAMES, parameters)}
    new_states_dict = {n: 0.999*states_dict[n] + 1e-3*parameters_dict[n] for n in STATE_NAMES}
    return np.array(list(new_states_dict.values()))


states = np.ones((PROPAGATION_HORIZON_LENGTH, STATE_DIMENSION))
parameters = np.ones((STATE_DIMENSION,))

# Compile functions
sum_of_squares_loss(example_dynamics, states, parameters)
sum_of_squares_loss_gradient(example_dynamics, states, parameters)
```

Resulting XLA dump for the above

So my question is: what is the best way to write/wrap a function that acts on a PyTree, without incurring prohibitively large compile times?

esbenscriver commented 6 months ago

Did you find any solution to this problem? And do you know why the pytrees use so much memory?

jakevdp commented 6 months ago

I believe the issue is not with pytrees per se, but rather with defining hundreds of array arguments, and compiling functions which take hundreds of array arguments. Generally speaking, compilation costs should not scale with the size of the individual arrays being operated on, but we do expect them to scale with the number of array objects being passed to the function. In the best case, you might achieve linear scaling – but I suspect in reality you'll see closer to quadratic scaling with the number of array inputs. This is not unexpected, because it will generally lead to much larger programs which require much more logic to optimize.
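To make that scaling concrete, here is a minimal sketch (the toy dynamics, sizes, and timing approach are illustrative and not taken from the reports above) that compares the size of the lowered program and the compile time for a scan whose carry is a dict with many leaves versus a single stacked array:

```python
import time

import jax
import jax.numpy as jnp

# Illustrative sizes only.
N_LEAVES = 100
HORIZON = 200


def step_dict(state, _):
    # Carry is a dict with N_LEAVES separate scalar arrays.
    new_state = {k: 0.999 * v + 1e-3 for k, v in state.items()}
    return new_state, new_state


def step_array(state, _):
    # Carry is a single array of shape (N_LEAVES,).
    new_state = 0.999 * state + 1e-3
    return new_state, new_state


@jax.jit
def rollout_dict(state):
    return jax.lax.scan(step_dict, state, None, length=HORIZON)[1]


@jax.jit
def rollout_array(state):
    return jax.lax.scan(step_array, state, None, length=HORIZON)[1]


dict_state = {f"state_{i}": jnp.ones(()) for i in range(N_LEAVES)}
array_state = jnp.ones(N_LEAVES)

for name, fn, arg in [("dict", rollout_dict, dict_state),
                      ("array", rollout_array, array_state)]:
    lowered = fn.lower(arg)          # pre-optimisation program handed to XLA
    start = time.perf_counter()
    lowered.compile()
    elapsed = time.perf_counter() - start
    print(f"{name}: program text ~{len(lowered.as_text()):,} chars, "
          f"compile time {elapsed:.2f} s")
```

On a recent JAX you should see the dict variant's program text and compile time grow with N_LEAVES, while the array variant stays essentially flat, which is consistent with the HLO sizes reported above.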

Does that answer your question?
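If the goal is simply to keep readable, name-based access without paying that per-array cost, one possible pattern is to keep the names in a static Python-side index map and operate on entries of a single stacked array inside jit/scan. A rough sketch with made-up state names and dynamics:

```python
import jax
import jax.numpy as jnp

# Made-up names; the index map is plain Python and fixed at trace time.
STATE_NAMES = ["position", "velocity", "temperature"]
IDX = {name: i for i, name in enumerate(STATE_NAMES)}


def example_dynamics(state, parameters):
    # Readable access by name, but on entries of one array rather than on
    # many separate dict leaves.
    position = state[IDX["position"]]
    velocity = state[IDX["velocity"]]
    temperature = state[IDX["temperature"]]
    return jnp.stack([
        position + 1e-3 * velocity,
        0.999 * velocity + 1e-3 * parameters[IDX["velocity"]],
        temperature,
    ])


@jax.jit
def rollout(initial_state, parameters):
    def step(state, _):
        new_state = example_dynamics(state, parameters)
        return new_state, new_state
    return jax.lax.scan(step, initial_state, None, length=200)[1]


initial_state = jnp.ones(len(STATE_NAMES))
parameters = jnp.ones(len(STATE_NAMES))
trajectory = rollout(initial_state, parameters)  # shape (200, 3)
```

The compiled program then sees only two array inputs, regardless of how many named fields the state has.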