google-deepmind / graphcast

How to fine-tune graphcast? #43

Open AndrewYangnb opened 6 months ago

AndrewYangnb commented 6 months ago

Is there any sample code for fine-tuning graphcast?

illuSION-crypto commented 6 months ago

I also want to know. I can run the notebook code and get the loss and gradients, but I find that the loss doesn't backpropagate into the parameters. By the way, how to save a randomly initialized model is also not mentioned.

ChrisAGBlake commented 6 months ago

I managed to get it working with a few modifications to the example notebook code. Here are the changes I made (ignoring all unchanged code from the example):

import optax

# modify the gradients function signature
def grads_fn(params, state, inputs, targets, forcings, model_config, task_config):
    def _aux(params, state, i, t, f):
        (loss, diagnostics), next_state = loss_fn.apply(params, state, jax.random.PRNGKey(0), model_config, task_config, i, t, f)
        return loss, (diagnostics, next_state)
    (loss, (diagnostics, next_state)), grads = jax.value_and_grad(_aux, has_aux=True)(params, state, inputs, targets, forcings)
    return loss, diagnostics, next_state, grads

# remove `with_params` from jitted grads function
grads_fn_jitted = jax.jit(with_configs(grads_fn))

# setup optimiser
lr = 1e-3
optimiser = optax.adam(lr, b1=0.9, b2=0.999, eps=1e-8)
opt_state = optimiser.init(params)

# calculate loss and gradients
loss, diagnostics, next_state, grads = grads_fn_jitted(params, state, inputs, targets, forcings)

# update
updates, opt_state = optimiser.update(grads, opt_state)
params = optax.apply_updates(params, updates)
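
To turn this single update into a fine-tuning loop, something like the following should work (just a rough sketch; load_batch is a placeholder for however you iterate over your training inputs/targets/forcings):

num_epochs = 10
for epoch in range(num_epochs):
    for inputs, targets, forcings in load_batch():
        # compute loss and gradients for this batch
        loss, diagnostics, next_state, grads = grads_fn_jitted(params, state, inputs, targets, forcings)
        # apply the optimiser update
        updates, opt_state = optimiser.update(grads, opt_state)
        params = optax.apply_updates(params, updates)
    print(f'epoch {epoch}: loss {float(loss)}')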

mjwillson commented 5 months ago

I'm afraid we don't provide a training script in this codebase, as our training setup is quite tied to internal infrastructure, but I think we provide enough to construct one yourself, and the above is a good start. Note that you may struggle to fine-tune the 0.25deg model unrolled to 3 days without extra tricks to save GPU/TPU RAM (see the other issue on this).
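
For what it's worth, one generic JAX trick (not our training setup, just a sketch) is to rematerialise each autoregressive step with jax.checkpoint, so per-step activations are recomputed during the backward pass instead of being stored for the whole unroll. step_fn and per_step_loss below are hypothetical stand-ins for a single-step forward call and its loss:

import jax

def unrolled_loss(params, init_inputs, step_fn, per_step_loss, num_steps):
    # Recompute each step's activations in the backward pass (jax.remat)
    # rather than keeping them all alive for the whole unroll.
    remat_step = jax.checkpoint(step_fn)
    x, total = init_inputs, 0.0
    for _ in range(num_steps):
        x = remat_step(params, x)
        total = total + per_step_loss(x)
    return total / num_steps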

I'll leave this open in case any others want to contribute example code.

monte-flora commented 4 months ago

For saving and loading the model parameters, I came up with these functions. @illuSION-crypto

import jax
import numpy as np
import jax.numpy as jnp
import os 

def flatten_dict(d, parent_key='', sep='//'):
    items = []
    for k, v in d.items():
        new_key = f"{parent_key}{sep}{k}" if parent_key else k
        if isinstance(v, dict):
            items.extend(flatten_dict(v, new_key, sep=sep).items())
        else:
            items.append((new_key, v))
    return dict(items)

def save_model_params(d, file_path):
    flat_dict = flatten_dict(d)
    # Convert JAX arrays to NumPy for saving
    np_dict = {k: np.array(v) if isinstance(v, jnp.ndarray) else v for k, v in flat_dict.items()}
    np.savez(file_path, **np_dict)

params_path = os.path.join('path/to/params', 'params.npz')
save_model_params(params, params_path)

def unflatten_dict(d, sep='//'):
    result_dict = {}
    for flat_key, value in d.items():
        keys = flat_key.split(sep)
        # walk/create the nested dicts down to the leaf's parent
        node = result_dict
        for key in keys[:-1]:
            node = node.setdefault(key, {})
        node[keys[-1]] = value
    return result_dict

def load_model_params(file_path):
    with np.load(file_path, allow_pickle=True) as npz_file:
        # Convert NumPy arrays back to JAX arrays
        jax_dict = {k: jnp.array(v) for k, v in npz_file.items()}
    return unflatten_dict(jax_dict)

params = load_model_params(params_path)
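
For just saving a randomly initialized model, an even simpler option (a sketch, not specific to this repo) is to pickle the nested params dict directly:

import pickle
import jax

def save_params_pkl(params, file_path):
    # pull arrays back to host memory before pickling
    with open(file_path, 'wb') as f:
        pickle.dump(jax.device_get(params), f)

def load_params_pkl(file_path):
    with open(file_path, 'rb') as f:
        return pickle.load(f)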

monte-flora commented 4 months ago

Thanks @ChrisAGBlake for sharing! Is it possible to modify your script to allow for multiple GPUs?

ChrisAGBlake commented 3 months ago

Sure, I used xarray_jax.pmap versions of the functions to distribute the computation across multiple GPUs:

    # setup the update function
    @functools.partial(xarray_jax.pmap, dim='device', axis_name='device')
    def multi_gpu_update(params, state, opt_state, inputs, targets, forcings):

        # calculate loss and gradients
        loss, diagnostics, next_state, grads = grads_fn_jitted(params, state, inputs, targets, forcings)

        # combine the gradients across devices
        grads = jax.lax.pmean(grads, axis_name='device')

        # combine the loss across devices
        loss = jax.lax.pmean(loss, axis_name='device')

        # update
        updates, new_opt_state = optimiser.update(grads, opt_state)
        new_params = optax.apply_updates(params, updates)

        return new_params, loss, new_opt_state

    # setup the loss function for evaluation
    @functools.partial(xarray_jax.pmap, dim='device', axis_name='device')
    def multi_gpu_loss(params, state, inputs, targets, forcings):

        # calculate loss
        (loss, diagnostics), next_state = loss_fn_jitted(params, state, jax.random.PRNGKey(0), inputs, targets, forcings)

        # combine the loss across devices
        loss = jax.lax.pmean(loss, axis_name='device')

        return loss
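
Before calling these, params, state and opt_state need a leading device dimension. A rough sketch of the setup (assuming xarray_jax.pmap maps plain pytree leaves over the leading axis like jax.pmap does, and that your inputs/targets/forcings already carry a 'device' dimension with one batch per device):

import jax

# replicate parameters, state and optimiser state across local devices
devices = jax.local_devices()
params = jax.device_put_replicated(params, devices)
state = jax.device_put_replicated(state, devices)
opt_state = jax.device_put_replicated(opt_state, devices)

# one training step across all devices
params, loss, opt_state = multi_gpu_update(params, state, opt_state, inputs, targets, forcings)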

vsansi commented 2 months ago

Thanks @ChrisAGBlake and @monte-flora for sharing. I saved and loaded the model with the save_model_params and load_model_params functions respectively, but I have trouble making predictions with it. When I use the autoregressive rollout (loop in Python) and loss computation code, the loss looks like that of the randomly initialized model; it should be way lower than what I got.

        prediction = rollout.chunked_prediction(
                      run_forward_jitted,
                      rng=jax.random.PRNGKey(0),
                      inputs=train_inputs,
                      targets_template=train_targets * np.nan,
                      forcings=train_forcings)

        predictions.append(prediction)

        if i in [0,1,2]:
            # @title Loss computation (autoregressive loss over multiple steps)
            loss, diagnostics = loss_fn_jitted(
                rng=jax.random.PRNGKey(0),
                inputs=train_inputs,
                targets=train_targets,
                forcings=train_forcings)
            print("Loss:", float(loss))

It is as if the newly loaded parameters are not taken into account in the predictions. Any clue on how to make this work? Thank you.

AndrewYangnb commented 2 months ago

Thanks @ChrisAGBlake. I've run into a problem: my GPU only has 12 GB of memory, and OOM occurs when calculating the loss and gradients. Is there any solution?

ChrisAGBlake commented 2 months ago

You'll probably need more GPU memory than that unless you're trying to train a low-resolution model. One workaround is to offload to the CPU, which you can do by putting the following couple of lines at the top of your script:

import os
os.environ['XLA_PYTHON_CLIENT_MEM_FRACTION'] = '2.0' 

This will try to use your CPU to double the amount of memory you have available beyond the GPU (24 GB in total in this case). You can adjust '2.0' to whatever value is required to be able to do the update (within your RAM availability).
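
A related general JAX knob that sometimes helps (not specific to graphcast) is turning off preallocation, so memory is allocated on demand instead of most of the GPU being grabbed up front:

import os
# allocate GPU memory as needed instead of preallocating ~75% of it at startup
os.environ['XLA_PYTHON_CLIENT_PREALLOCATE'] = 'false'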

I have found that I require ~80GB with a model at 0.25 deg resolution.

Another solution is just to use vast.ai, AWS, GCP, etc.

ChrisAGBlake commented 2 months ago

@vsansi I load the trained model like this:

  # load the model
  with open(checkpoint_file, 'rb') as f:
      ckpt = checkpoint.load(f, graphcast.CheckPoint)
  params = ckpt.params
  state = {}
  model_config = ckpt.model_config
  task_config = ckpt.task_config

Then I generate predictions like this:

    # run forward autoregressively to get the predictions
    run_forward_jitted = drop_state(with_params(jax.jit(with_configs(run_forward.apply))))
    predictions = rollout.chunked_prediction(
        run_forward_jitted,
        rng=jax.random.PRNGKey(0),
        inputs=inputs,
        targets_template=targets * np.nan,
        forcings=forcings)

    # write out the predictions
    pred_file = f'{save_dir}/{date.strftime("%Y-%m-%dT%H")}_predictions.nc'
    predictions.to_netcdf(pred_file)
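
One thing to double-check if you're loading fine-tuned params (e.g. with load_model_params from above): if I remember the example notebook correctly, with_params just partially applies whatever params and state are in scope when it's called, so run_forward_jitted needs to be rebuilt after the new parameters are loaded, roughly:

    # hypothetical path to the fine-tuned parameters saved earlier
    params = load_model_params(params_path)
    state = {}
    # rebuild the jitted forward function so it closes over the new params
    run_forward_jitted = drop_state(with_params(jax.jit(with_configs(run_forward.apply))))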

Hope this helps.