google-deepmind / recurrentgemma

Open weights language model from Google DeepMind, based on Griffin.
Apache License 2.0
597 stars 25 forks

Function to save trained parameters? #1

Closed tybalex closed 5 months ago

tybalex commented 5 months ago

I was able to fine-tune recurrentgemma-2b-it using JAX, loading the model with recurrentgemma.load_parameters() following this colab example. But once I have a trained model/checkpoint, I can't find anything to save the params with. I would guess it should be something like https://github.com/google-deepmind/recurrentgemma/blob/main/examples/simple_run_jax.py, although that one doesn't work for me.

Nush395 commented 5 months ago

Hi @tybalex, thanks for the issue! You can use orbax to save the checkpoint as described in this example - https://flax.readthedocs.io/en/latest/guides/training_techniques/use_checkpointing.html#with-orbax.

This is the method that is used in the simple_run_jax example. What is the issue you're encountering with using that?
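
For reference, here is roughly what that save pattern looks like (a minimal sketch following the Flax guide; trained_params stands in for your fine-tuned parameter pytree and the path is just an example):

import orbax.checkpoint
from flax.training import orbax_utils

# Any pytree can be checkpointed, e.g. the fine-tuned parameters.
ckpt = {"params": trained_params}

orbax_checkpointer = orbax.checkpoint.PyTreeCheckpointer()
save_args = orbax_utils.save_args_from_target(ckpt)
orbax_checkpointer.save("/tmp/finetuned_checkpoint", ckpt, save_args=save_args)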

tybalex commented 5 months ago

Thank you for the instructions. So I saved the checkpoint with:

trained_params = train_loop(
    model=model,
    params=params,
    optimizer=optimizer,
    train_ds=train_ds,
    validation_ds=validation_ds,
    num_steps=num_steps,
)

ckpt = {
    "model": trained_params,
    "config": config,
}
orbax_checkpointer = orbax.checkpoint.PyTreeCheckpointer()
save_args = orbax_utils.save_args_from_target(ckpt)
orbax_checkpointer.save("/content/recurrentgemma-2b-it", ckpt, save_args=save_args)

and then loaded it again:

params = recurrentgemma.load_parameters("/content/recurrentgemma-2b-it", "single_device")
config = recurrentgemma.GriffinConfig.from_flax_params_or_variables(params)
model = recurrentgemma.Griffin(config)

Then I got this error:

---------------------------------------------------------------------------
KeyError                                  Traceback (most recent call last)
      2 # params =  recurrentgemma.load_parameters(ckpt_path, "single_device")
      3 params =  recurrentgemma.load_parameters("/content/recurrentgemma-2b-it", "single_device")
----> 4 config = recurrentgemma.GriffinConfig.from_flax_params_or_variables(params)
      5 model = recurrentgemma.Griffin(config)

/usr/local/lib/python3.10/dist-packages/recurrentgemma/common.py in from_flax_params_or_variables(cls, flax_params_or_variables, max_sequence_length, preset)
    152       params = flax_params_or_variables
    153 
--> 154     vocab_size, width = params["embedder"]["input_embedding"].shape
    155     mlp_exp_width = params["blocks.0"]["mlp_block"]["ffw_up"]["w"].shape[-1]
    156 

KeyError: 'embedder'

Looks like something is missing? In general I just don't know what I should save in the ckpt dict, so it would be nice if I could find an example of saving checkpoints in your example notebooks.

Nush395 commented 5 months ago

In general you can save any pytree with an Orbax checkpoint.

We've added a simple util for saving that should be useful to you - https://github.com/google-deepmind/recurrentgemma/blob/main/recurrentgemma/jax/utils.py#L25.
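
For completeness, restoring is the mirror image of your save snippet (a sketch only; restore() hands back the same pytree structure that was passed to save()):

import orbax.checkpoint

orbax_checkpointer = orbax.checkpoint.PyTreeCheckpointer()
# With your ckpt dict, the trained parameters come back under the "model" key.
restored = orbax_checkpointer.restore("/content/recurrentgemma-2b-it")
trained_params = restored["model"]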

botev commented 5 months ago

@tybalex the function GriffinConfig.from_flax_params_or_variables expects to be passed either a direct dict of parameters or, as Nush did, a dict that has a "params" key. In your case the parameters are under the "model" key, so you need to do something like:

params = recurrentgemma.load_parameters("/content/recurrentgemma-2b-it", "single_device")["model"]
config = recurrentgemma.GriffinConfig.from_flax_params_or_variables(params)
model = recurrentgemma.Griffin(config)

Or save the trained_params directly.
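
For example, that second option could look like this (a sketch; the import alias and checkpoint path are assumptions, matching the fine-tuning colab and the path used above):

import orbax.checkpoint
from flax.training import orbax_utils
from recurrentgemma import jax as recurrentgemma  # assumed import alias, as in the JAX colab

# Save the raw parameter pytree directly, without wrapping it in another dict.
orbax_checkpointer = orbax.checkpoint.PyTreeCheckpointer()
save_args = orbax_utils.save_args_from_target(trained_params)
orbax_checkpointer.save("/content/recurrentgemma-2b-it-finetuned", trained_params, save_args=save_args)

# load_parameters() then yields the params dict that
# GriffinConfig.from_flax_params_or_variables expects.
params = recurrentgemma.load_parameters("/content/recurrentgemma-2b-it-finetuned", "single_device")
config = recurrentgemma.GriffinConfig.from_flax_params_or_variables(params)
model = recurrentgemma.Griffin(config)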

tybalex commented 5 months ago

Thank you both for the help. Saving the trained_params directly works for me!