google-deepmind / dm-haiku

JAX-based neural network library
https://dm-haiku.readthedocs.io
Apache License 2.0

Proper way to handle large state objects? #216

Open ttt733 opened 3 years ago

ttt733 commented 3 years ago

I was experiencing some slowness in my forward function, so I tried using the experimental visualization tool to debug it. One part of it in particular stuck out: [screenshot: graph visualization with a long horizontal stretch] It goes on even further to the right. Up close, that blue line is a giant cond: [screenshot: zoomed-in view of the cond]

I checked my code, and I narrowed it down to this cond. Commenting it out removes that weird bit from the graph and speeds up the function's execution by ~20%.

```python
graph_shape, current_index_shape, edge_counts_shape = graph_init_info
graph = hk.get_state('graph', shape=graph_shape, init=jnp.zeros)
current_index = hk.get_state('max_index', current_index_shape, dtype='int32', init=jnp.zeros)
edge_counts = hk.get_state('edge_counts', edge_counts_shape, dtype='int32', init=jnp.zeros)

# Reset the graph when it grows too large
graph, current_index, edge_counts = hk.cond(
    current_index > max_graph_nodes,
    lambda _: (jnp.zeros(graph_shape), 0, jnp.zeros(edge_counts_shape, dtype='int32')),
    lambda _: (graph, current_index, edge_counts),
    operand=None,
)
```

graph and edge_counts are very large ndarrays. current_index is a counter that's incremented each time the function's called, and max_graph_nodes is just a static int. (Basically, I'm trying to re-initialize the graph object once every few thousand times through the function.) graph stores values output by some of the transformers in my model, which I'm guessing is why you can see them feeding into the cond in the graph.

I could understand it if JAX were just compiling around these objects in a weird way, but the fact that the function slows down so much when adding a cond, which I don't expect to do much of anything most of the time, makes me think I'm doing something wrong with the Haiku state. Are there any special considerations that need to be taken with these objects to avoid the compiler behaving this way?

tomhennigan commented 3 years ago

Hi Trevor, hk.cond passes all Haiku state in and out of the cond so that module parameters and state can be created/updated inside the branches. In this case, it looks to me like your branch functions do not actually call into Haiku modules, so you can probably just use plain jax.lax.cond (which will not have all the extra operands).

JAX made a change earlier this year known as "omnistaging"; I think this may allow us to only pass state out of the cond (rather than both in and out). It is possible we could also optimise the branch functions to return only the state that was actually updated inside the cond. This would mean that if you did have to use hk.cond (e.g. because you use Haiku modules inside the branch functions), we could make it a lot more minimal. I'll take a look later this week.
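A self-contained sketch of what the plain jax.lax.cond version of the reset might look like (the shapes and the `max_graph_nodes` value here are stand-ins for the much larger arrays in the issue):

```python
import jax
import jax.numpy as jnp

# Illustrative stand-ins for the large arrays described in the issue.
graph_shape = (8, 8)
edge_counts_shape = (8,)
max_graph_nodes = 4

def maybe_reset(graph, current_index, edge_counts):
    # The branches contain no Haiku modules, so plain jax.lax.cond suffices;
    # no Haiku state is threaded through the branch functions.
    return jax.lax.cond(
        current_index > max_graph_nodes,
        lambda: (jnp.zeros(graph_shape),
                 jnp.array(0, dtype=jnp.int32),
                 jnp.zeros(edge_counts_shape, dtype=jnp.int32)),
        lambda: (graph, current_index, edge_counts),
    )

graph = jnp.ones(graph_shape)
edge_counts = jnp.ones(edge_counts_shape, dtype=jnp.int32)
# index exceeds max_graph_nodes, so everything is reset to zeros
g, i, e = maybe_reset(graph, jnp.array(5, dtype=jnp.int32), edge_counts)
```

Both branches must return the same pytree of shapes and dtypes, which is why the "reset" branch builds an explicit int32 zero for the index.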

ttt733 commented 3 years ago

Thanks for taking a look, @tomhennigan. I had switched everything in the transformed function to the Haiku version, but I'll avoid doing that where I can for now. Switching over to jax.lax.cond does remove the wild appendage I saw in the graph; however, the speed remains the same: I still see about a 20% speedup when removing the cond entirely vs. keeping either the JAX or Haiku cond there.

I'm not sure why it should be such a bottleneck - it's at the beginning of the function, and everything after it is just passing stuff forward through some transformer modules, at the moment. Does that seem like something I should expect, or is it worth looking into more? If you think that question's a better fit for discussion on the JAX repo, I can move it over there.

tomhennigan commented 3 years ago

That is odd. If the issue is not Haiku specific you may find better answers on the JAX repo (there are more people looking there).

One suggestion I would have is to try using jax.lax.select(pred, true_val, false_val) instead of cond(pred, true_fn, false_fn). We use this in our mixed precision examples (e.g. here) since we have a very cheap branch (it does basically nothing) that we rarely need to take, and most of the time we take the expensive branch. It ended up being faster to always compute both branches and then just select one output based on a boolean predicate.

AFAIK this is faster because a real cond (i.e. executing only one of the branches) on GPU requires a round trip of data from the GPU to the CPU in order to decide whether to take the LHS or RHS branch, while a select can be computed entirely on the GPU.
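Applied to the reset in this issue, the select version might look like the following sketch (shapes are illustrative; the scalar predicate is broadcast explicitly so lax.select sees matching shapes):

```python
import jax
import jax.numpy as jnp

def reset_with_select(graph, current_index, max_graph_nodes):
    # Both "branches" are computed unconditionally; lax.select then picks
    # one elementwise on-device, avoiding the host round trip of a real cond.
    pred = current_index > max_graph_nodes       # scalar bool
    fresh = jnp.zeros_like(graph)                # the cheap, rarely taken branch
    return jax.lax.select(jnp.broadcast_to(pred, graph.shape), fresh, graph)

graph = jnp.ones((4, 4))
kept = reset_with_select(graph, jnp.array(2), 10)    # pred False -> keep graph
reset = reset_with_select(graph, jnp.array(20), 10)  # pred True  -> zeros
```

The trade-off is that the work of both branches is always done, which only pays off when the branch you usually take is the expensive one anyway, as described above.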

ttt733 commented 3 years ago

That's very useful info, thanks! Leaving this issue open since you mentioned looking at the JAX omnistaging option, but feel free to close it whenever you're ready.