Closed tybalex closed 5 months ago
Hi @tybalex, thanks for the issue! You can use orbax to save the checkpoint as described in this example: https://flax.readthedocs.io/en/latest/guides/training_techniques/use_checkpointing.html#with-orbax. This is the method used in the simple_run_jax example. What is the issue you're encountering with it?
Thank you for the instructions. So I saved the checkpoint with:

import orbax.checkpoint
from flax.training import orbax_utils

trained_params = train_loop(
    model=model,
    params=params,
    optimizer=optimizer,
    train_ds=train_ds,
    validation_ds=validation_ds,
    num_steps=num_steps,
)
ckpt = {
    "model": trained_params,
    "config": config,
}
orbax_checkpointer = orbax.checkpoint.PyTreeCheckpointer()
save_args = orbax_utils.save_args_from_target(ckpt)
orbax_checkpointer.save("/content/recurrentgemma-2b-it", ckpt, save_args=save_args)
and then load it again:
params = recurrentgemma.load_parameters("/content/recurrentgemma-2b-it", "single_device")
config = recurrentgemma.GriffinConfig.from_flax_params_or_variables(params)
model = recurrentgemma.Griffin(config)
Then I got this error:
---------------------------------------------------------------------------
KeyError Traceback (most recent call last)
2 # params = recurrentgemma.load_parameters(ckpt_path, "single_device")
3 params = recurrentgemma.load_parameters("/content/recurrentgemma-2b-it", "single_device")
----> 4 config = recurrentgemma.GriffinConfig.from_flax_params_or_variables(params)
5 model = recurrentgemma.Griffin(config)
/usr/local/lib/python3.10/dist-packages/recurrentgemma/common.py in from_flax_params_or_variables(cls, flax_params_or_variables, max_sequence_length, preset)
152 params = flax_params_or_variables
153
--> 154 vocab_size, width = params["embedder"]["input_embedding"].shape
155 mlp_exp_width = params["blocks.0"]["mlp_block"]["ffw_up"]["w"].shape[-1]
156
KeyError: 'embedder'
Looks like something is missing? In general I just don't know what I should save in the ckpt dict, so it would be nice if I could find an example of saving checkpoints in your example notebooks.
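For reference, the shape of this KeyError can be reproduced with plain dicts (the keys below mirror the ones from_flax_params_or_variables reads; the values are toy placeholders, not real weights):

```python
# Toy stand-in for the trained parameter pytree.
trained_params = {
    "embedder": {"input_embedding": "(vocab_size, width) array"},
    "blocks.0": {"mlp_block": {"ffw_up": {"w": "array"}}},
}

# The checkpoint dict from the snippet above nests the parameters
# one level deeper, under "model":
ckpt = {"model": trained_params, "config": "GriffinConfig"}

# from_flax_params_or_variables looks up params["embedder"] at the top
# level, which is why the restored dict raises KeyError('embedder'):
assert "embedder" not in ckpt         # top level only has "model" / "config"
assert "embedder" in ckpt["model"]    # the weights sit one level down
```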
In general you can save any pytree with an orbax checkpointer.
We've added a simple util for saving that should be useful to you - https://github.com/google-deepmind/recurrentgemma/blob/main/recurrentgemma/jax/utils.py#L25.
@tybalex the function GriffinConfig.from_flax_params_or_variables
expects to be passed either a direct dict of parameters or, as nush did, a dict that has a "params" key. In your case the parameters are under the "model" key, so you need to do something like:
params = recurrentgemma.load_parameters("/content/recurrentgemma-2b-it", "single_device")["model"]
config = recurrentgemma.GriffinConfig.from_flax_params_or_variables(params)
model = recurrentgemma.Griffin(config)
Or save the trained_params directly.
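Sketched with plain dicts (placeholder values standing in for the real weights and config), the two fixes look like:

```python
# Toy parameter pytree with the top-level key the config reader expects.
trained_params = {"embedder": {"input_embedding": "array"}}

# Fix 1: keep the {"model": ..., "config": ...} checkpoint layout, but
# index into "model" after loading (as load_parameters(...)["model"] does):
ckpt = {"model": trained_params, "config": "GriffinConfig"}
params = ckpt["model"]
assert "embedder" in params

# Fix 2: save trained_params directly, so the loaded dict already has
# the layout from_flax_params_or_variables expects:
ckpt_direct = trained_params
assert "embedder" in ckpt_direct
```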
Thank you both for the help. Saving the trained_params directly works for me!
I was able to fine-tune recurrentgemma-2b-it using JAX by loading the model with
recurrentgemma.load_parameters()
following this colab example. But once I have a trained model/checkpoint, I can't find anything yet to save the params. I would guess it should be something like https://github.com/google-deepmind/recurrentgemma/blob/main/examples/simple_run_jax.py? Although that one doesn't work for me.