google / flax

Flax is a neural network library for JAX that is designed for flexibility.
https://flax.readthedocs.io
Apache License 2.0

Flax NNX and Orbax Checkpointing require hacking to work together #4383

Open hdrwilkinson opened 1 week ago

hdrwilkinson commented 1 week ago

I'm building a system using flax.nnx and orbax.checkpoint. However, saving and restoring models is overly complicated because flax.nnx uses the new typed keys from jax.random.key() rather than the legacy jax.random.PRNGKey().
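For reference, the difference described here is visible directly in JAX. The legacy `jax.random.PRNGKey` returns a raw `uint32` array, while `jax.random.key` returns a scalar typed key array whose extended dtype (`key<fry>` for the default threefry implementation) is not a NumPy dtype. A minimal sketch, assuming a recent JAX with default PRNG settings:

```python
import jax
import jax.numpy as jnp

# Legacy style: a raw uint32 array of shape (2,) for the default threefry PRNG.
legacy_key = jax.random.PRNGKey(0)

# New style (used by flax.nnx): a scalar typed key array with an extended dtype.
typed_key = jax.random.key(0)

print(legacy_key.dtype, legacy_key.shape)  # uint32 (2,)
print(typed_key.dtype, typed_key.shape)    # key<fry> ()

# Typed keys are recognised through JAX's extended prng_key dtype.
print(jax.dtypes.issubdtype(typed_key.dtype, jax.dtypes.prng_key))   # True
print(jax.dtypes.issubdtype(legacy_key.dtype, jax.dtypes.prng_key))  # False
```

That extended dtype is what surfaces later in this thread as `TypeError: Cannot interpret 'key<fry>' as a data type`.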

I have had to create a workaround where every piece of state with rng and key in its path is converted from dtype=key<fry> to a format that can be saved. Then, on restoration, it has to be converted back.
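One generic way to express this kind of round trip (a sketch, not the notebook's code; the helper names `keys_to_data` and `data_to_keys` are hypothetical) is to map over the state pytree with `jax.random.key_data` and `jax.random.wrap_key_data`, detecting keys by dtype instead of by `rng`/`key` in the path. It assumes the state is an ordinary JAX pytree, which `nnx.State` is registered as; a plain dict stands in for it here.

```python
import jax
import jax.numpy as jnp


def is_typed_key(x) -> bool:
    """True for typed PRNG key arrays such as those created by jax.random.key()."""
    return hasattr(x, "dtype") and jax.dtypes.issubdtype(x.dtype, jax.dtypes.prng_key)


def keys_to_data(tree):
    """Hypothetical helper: replace every typed key leaf with its raw uint32 data."""
    return jax.tree.map(
        lambda x: jax.random.key_data(x) if is_typed_key(x) else x, tree
    )


def data_to_keys(raw_tree, template):
    """Hypothetical helper: re-wrap raw key data. `template` (e.g. a freshly
    initialised state) tells us which leaves were keys and which PRNG
    implementation they used."""
    return jax.tree.map(
        lambda raw, ref: (
            jax.random.wrap_key_data(raw, impl=jax.random.key_impl(ref))
            if is_typed_key(ref)
            else raw
        ),
        raw_tree,
        template,
    )


# Toy pytree standing in for a model state holding both params and an RNG key.
state = {"kernel": jnp.ones((2, 3)), "dropout_key": jax.random.key(0)}

saveable = keys_to_data(state)            # every leaf now has an ordinary dtype
restored = data_to_keys(saveable, state)  # typed keys are back
```

Checking the dtype rather than the path keeps the conversion independent of how layers name their RNG attributes. When restoring through Orbax, the restore target would presumably need the same treatment (for example `keys_to_data(template)`), since its key leaves would otherwise hit the same dtype error.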

I am attaching a link to a notebook explaining what I've done, but I would be keen to hear if there are simpler workarounds. Or, preferably, is there a way to simply save and restore models?

https://colab.research.google.com/drive/1ozln9ejG7eRtxvbkqHYU3K6OyPvveH9w?usp=sharing

Note: I am also opening an issue on Orbax (#1337) to see if there is a fix on their side.

cgarciae commented 1 week ago

Thank you! I'll contact the Orbax team to see if they can fix this on their end.

mishmish66 commented 1 week ago

Hey! Here's a quick and dirty workaround.

The general idea is to use nnx.split with NNX filters to separate the nnx.RngState variables from the rest of the state and then not save them:

```python
graphdef, rng_state, other_state = nnx.split(model, nnx.RngState, ...)
```

and then just save other_state instead of the full state. I've edited your Colab notebook to demonstrate this.

This means the RNG state will not be restored, which might be suboptimal in certain scenarios but should work for most cases. Hope it helps!

jkyl commented 12 hours ago

Another workaround for Dropout layers, and maybe custom layers too if they follow the same pattern, is to initialize them without the rngs arg, and only pass the RNG at __call__ time, like so:

```python
import jax.numpy as jnp
import orbax.checkpoint as ocp
from flax import nnx

# Init dropout without the rngs arg.
model = nnx.Dropout(0.5)

# Pass the RNG at call time instead.
output = model(jnp.ones(()), rngs=nnx.Rngs(0))

# This now works.
ckpt_dir = ocp.test_utils.erase_and_create_empty("/tmp/my-checkpoints/")
checkpointer = ocp.StandardCheckpointer()
_, state = nnx.split(model)
checkpointer.save(ckpt_dir / "state", state)
checkpointer.wait_until_finished()  # StandardCheckpointer saves asynchronously
```

Whereas if the RNG is supplied at initialization, the save call throws the following:

```
TypeError: Cannot interpret 'key<fry>' as a data type
```

But this is only a workaround: the RNG state is still not serialized, and it makes for a more verbose call signature.