google / flax

Flax is a neural network library for JAX that is designed for flexibility.
https://flax.readthedocs.io
Apache License 2.0

[Feature Request] Modular checkpointing of Flax module #3653

Open yyyliang opened 8 months ago

yyyliang commented 8 months ago

Hi Flax Experts! I'm looking for a way to perform modular checkpointing of a Flax module.

Specifically, we want to save or restore any individual layer or submodule of a Flax model with Orbax checkpointing. This requires a well-defined mapping between Flax modules <--> params <--> Orbax checkpoints. Currently, there is no easy way to obtain the first part, a {path: Module} mapping. Could you provide an API or tooling to enable this?

Here is a simple code snippet that demonstrates this feature request (FR).

import jax
import flax.linen as nn

# To save any Flax layer/module into a separate checkpoint file.
class Model1(nn.Module):
  @nn.compact
  def __call__(self, x: jax.Array) -> jax.Array:
    dense1 = nn.Dense(features=1, name='dense1')
    dense2 = nn.Dense(features=1000, name='dense2')
    dense3 = nn.Dense(features=1, name='dense3')
    x = dense1(x)
    # FR: Could we save `dense1` into its own checkpoint file `model1/dense1`?
    x = dense2(x)
    # FR: Could we save `dense2` into its own checkpoint file `model1/dense2`?
    x = dense3(x)
    # FR: Could we save `dense3` into its own checkpoint file `model1/dense3`?
    return x

# To load any Flax layer/module from a separate checkpoint file and make its params trainable or frozen.
class Model2(nn.Module):
  @nn.compact
  def __call__(self, x: jax.Array) -> jax.Array:
    dense1 = nn.Dense(features=1, name='dense1')
    dense2 = nn.Dense(features=1000, name='dense2')
    dense3 = nn.Dense(features=1, name='dense3')
    # FR: load `model1/dense1` into this `dense1` layer and freeze its params.
    x = dense1(x)
    # FR: load `model1/dense2` into this `dense2` layer and freeze its params.
    x = dense2(x)
    x = dense3(x)
    return x
chiamp commented 8 months ago

#3654 should solve this problem.