google-deepmind / graphcast

Apache License 2.0

How to save the model #34

Closed: AndrewYangnb closed this issue 6 months ago

AndrewYangnb commented 7 months ago

I have successfully run graphcast_demo.ipynb locally, but how do I save the model parameters after the final training step on the sample dataset?

I tried the following code:

```python
# @title Autoregressive rollout (keep the loop in JAX)
print("Inputs:  ", train_inputs.dims.mapping)
print("Targets: ", train_targets.dims.mapping)
print("Forcings:", train_forcings.dims.mapping)

predictions = run_forward_jitted(
    rng=jax.random.PRNGKey(0),
    inputs=train_inputs,
    targets_template=train_targets * np.nan,
    forcings=train_forcings)
predictions

with open(f"dm_graphcast/params/test.npz", "wb") as f:
    checkpoint.dump(f, graphcast.CheckPoint)
```

But it doesn't work.

tewalds commented 6 months ago

You need to pass a graphcast.CheckPoint object, not the class definition. In the Colab that is the ckpt variable, but if you're not changing it, you can also just download the file directly. If you want to change it (e.g. fine-tune it), you'll need to construct a new object to pass to checkpoint.dump.
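A minimal, self-contained sketch of why the instance matters. `ToyCheckPoint`, `dump` and `load` below are hypothetical stand-ins written for illustration only; they are not graphcast's `CheckPoint` class or its `checkpoint` module, and they only assume the general idea of flattening a dataclass instance with nested parameter dicts into the keys of an `.npz` archive. An instance carries field values to write; the class object itself has nothing to serialize, so this toy `dump` rejects it.

```python
import dataclasses
import io
from typing import Any, BinaryIO

import numpy as np


# Hypothetical stand-in for graphcast.CheckPoint (NOT the real class,
# which has more fields such as model and task configs).
@dataclasses.dataclass(frozen=True)
class ToyCheckPoint:
    params: dict[str, Any]
    description: str


def _flatten(value: Any, prefix: str = "") -> dict[str, Any]:
    """Flattens nested dicts / dataclass instances into 'a:b:c'-keyed leaves."""
    # Only an *instance* carries field values; a class object is not expanded.
    if dataclasses.is_dataclass(value) and not isinstance(value, type):
        value = {f.name: getattr(value, f.name) for f in dataclasses.fields(value)}
    if isinstance(value, dict):
        flat: dict[str, Any] = {}
        for key, child in value.items():
            flat.update(_flatten(child, f"{prefix}{key}:"))
        return flat
    if not prefix:
        # Top-level value that is neither a dict nor a dataclass instance,
        # e.g. the ToyCheckPoint class itself.
        raise TypeError(f"expected a dataclass instance or dict, got {value!r}")
    return {prefix[:-1]: value}


def dump(dest: BinaryIO, value: Any) -> None:
    """Writes every leaf of `value` as a separate array in an .npz archive."""
    np.savez(dest, **_flatten(value))


def load(source: BinaryIO) -> ToyCheckPoint:
    """Reads an archive written by `dump` back into a ToyCheckPoint."""
    with np.load(source) as data:
        flat = {key: data[key] for key in data.files}
    nested: dict[str, Any] = {}
    for key, array in flat.items():
        *parents, leaf = key.split(":")
        node = nested
        for parent in parents:
            node = node.setdefault(parent, {})
        node[leaf] = array
    # Strings come back as 0-d arrays; .item() turns them into a Python str.
    return ToyCheckPoint(
        params=nested["params"],
        description=nested["description"].item(),
    )


# Usage: build an *instance* (e.g. holding fine-tuned params), then dump it.
ckpt = ToyCheckPoint(params={"layer": {"w": np.ones((2, 3))}},
                     description="fine-tuned")
buf = io.BytesIO()  # stands in for open("test.npz", "wb")
dump(buf, ckpt)
buf.seek(0)
restored = load(buf)
```

With graphcast itself, the same idea means constructing a new `graphcast.CheckPoint` from the updated params plus the other fields of the loaded `ckpt`, for example `graphcast.CheckPoint(params=params, model_config=ckpt.model_config, task_config=ckpt.task_config, description="fine-tuned on the sample dataset", license=ckpt.license)`, and passing that object to `checkpoint.dump(f, ...)`. That field list is an assumption based on the attributes the demo notebook reads from `ckpt`, so check it against the `CheckPoint` definition in `graphcast.py` before relying on it.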

AndrewYangnb commented 6 months ago

OK, thank you, I've got it.