Hi Flax Experts! I'm looking for a way to perform modular checkpointing of a Flax module.
Specifically, we want to save or restore any layer or submodule of a Flax model with Orbax checkpointing. This requires a meaningful mapping among Flax module <--> params <--> Orbax checkpoint. Currently, there is no easy way to get the first link, i.e. a `{path: Module}` mapping. Could you provide an API or tooling to enable this functionality?
Here is a simple code snippet to demonstrate this feature request (FR).
```python
import jax
import flax.linen as nn


# To save any Flax layer/module into separate checkpoint files.
class Model1(nn.Module):
    @nn.compact
    def __call__(self, x: jax.Array) -> jax.Array:
        dense1 = nn.Dense(features=1, name='dense1')
        dense2 = nn.Dense(features=1000, name='dense2')
        dense3 = nn.Dense(features=1, name='dense3')
        x = dense1(x)
        # FR: could we save `dense1` into its own checkpoint `model1/dense1`?
        x = dense2(x)
        # FR: could we save `dense2` into its own checkpoint `model1/dense2`?
        x = dense3(x)
        # FR: could we save `dense3` into its own checkpoint `model1/dense3`?
        return x
```
```python
# To load any Flax layer/module from a separate checkpoint file, and make its params trainable or frozen.
class Model2(nn.Module):
    @nn.compact
    def __call__(self, x: jax.Array) -> jax.Array:
        dense1 = nn.Dense(features=1, name='dense1')
        dense2 = nn.Dense(features=1000, name='dense2')
        dense3 = nn.Dense(features=1, name='dense3')
        # FR: load `model1/dense1` into this `dense1` layer of Model2 and freeze its params
        x = dense1(x)
        # FR: load `model1/dense2` into this `dense2` layer of Model2 and freeze its params
        x = dense2(x)
        x = dense3(x)
        return x
```