google-research / neuralgcm

Hybrid ML + physics model of the Earth's atmosphere
https://neuralgcm.readthedocs.io
Apache License 2.0
666 stars, 74 forks

Please help me set up a stochastic forecast! Thank you! #19

Closed: weatherforecasterwhai closed this issue 6 months ago

weatherforecasterwhai commented 7 months ago

Running the deterministic NeuralGCM, I found that it performs much better than FuXi, GraphCast, and FourCastNet on p_minus_e. However, NeuralGCM still performs worse within the first two days in my daily case studies, so I want to try the stochastic NeuralGCM. Just changing the .pkl name to the stochastic checkpoint does not seem to be enough: the model still outputs only deterministic predictions. I've tried adding settings to "model_config_str" like this: ... 'CHECKPOINT_MULTISTEP = True,' 'StochasticPhysicsParameterizationStep.checkpoint_substep = True',...

It still doesn't work. Please help me! Thank you.

kochkov92 commented 7 months ago

@weatherforecasterwhai - thanks for raising this. Can you please confirm that you are providing a different rng_key when calling encode in your experiment?

In our stochastic models, randomness is fully controlled by the rng_key; hence, given the same keys across rollouts, the outputs should be identical.
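The key semantics described above can be checked with plain JAX. This is a minimal sketch, not NeuralGCM code: the helper `sample_noise` is hypothetical and stands in for any stochastic component whose output depends only on the key it receives. It shows that reusing one key reproduces the same numbers exactly, that `jax.random.split` yields keys that produce different draws, and that re-creating the key from the same seed reproduces the whole sequence.

```python
import jax
import jax.numpy as jnp


def sample_noise(rng_key, shape=(3,)):
    # Hypothetical stand-in for a stochastic component: its output
    # depends only on rng_key (JAX has no hidden global RNG state).
    return jax.random.normal(rng_key, shape)


seed = 0
key = jax.random.PRNGKey(seed)

# Reusing the same key gives exactly the same "random" numbers.
a = sample_noise(key)
b = sample_noise(key)
same_key_identical = bool(jnp.array_equal(a, b))

# Splitting produces new, independent keys, so the draws differ.
k1, k2 = jax.random.split(key, 2)
c = sample_noise(k1)
d = sample_noise(k2)
split_keys_differ = not bool(jnp.array_equal(c, d))

# Re-creating the key from the same seed reproduces the same draws,
# which is what makes stochastic runs repeatable.
k1_again, _ = jax.random.split(jax.random.PRNGKey(seed), 2)
reproducible = bool(jnp.array_equal(c, sample_noise(k1_again)))

print(same_key_identical, split_keys_differ, reproducible)
```

Because JAX keys are explicit values, an ensemble only differs across members if each member is given its own key.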

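```python
import jax
import xarray

predictions = []
for rng_key in jax.random.split(jax.random.PRNGKey(seed), ensemble_size):
  # Use a distinct rng_key per ensemble member; identical keys give identical outputs.
  init_state = model.encode(input_data, input_forcings, rng_key)
  _, preds = model.unroll(init_state, forcings, ...)  # returns (final_state, predictions)
  predictions.append(model.data_to_xarray(preds, times=times))
combined_ds = xarray.concat(predictions, 'ensemble')
```
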
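The ensemble loop above can be exercised end to end with a toy model. This is a sketch under stated assumptions: `ToyStochasticModel`, `run_ensemble`, and their signatures are hypothetical illustrations, not the NeuralGCM API, and members are stacked with `jnp.stack` instead of `xarray.concat`. It contrasts one split key per member with the common mistake of reusing a single key, which yields identical members (zero spread), matching the symptom reported in this issue.

```python
import jax
import jax.numpy as jnp


class ToyStochasticModel:
    """Hypothetical stand-in for a stochastic model; NOT the NeuralGCM API."""

    def encode(self, inputs, rng_key):
        # Key-dependent perturbation of the initial state.
        return inputs + 0.1 * jax.random.normal(rng_key, inputs.shape)

    def unroll(self, state, steps):
        # Deterministic toy dynamics: damp the state each step.
        trajectory = []
        for _ in range(steps):
            state = 0.9 * state
            trajectory.append(state)
        return state, jnp.stack(trajectory)  # (final_state, predictions)


def run_ensemble(model, inputs, keys, steps):
    members = []
    for rng_key in keys:
        init_state = model.encode(inputs, rng_key)
        _, predictions = model.unroll(init_state, steps)
        members.append(predictions)
    return jnp.stack(members)  # shape: (ensemble, time, state)


model = ToyStochasticModel()
inputs = jnp.zeros(4)
seed, ensemble_size, steps = 0, 5, 3

# Correct: one split key per ensemble member.
good_keys = jax.random.split(jax.random.PRNGKey(seed), ensemble_size)
good = run_ensemble(model, inputs, good_keys, steps)
good_differ = bool(jnp.all(good.std(axis=0) > 0))

# Mistake: the same key for every member gives identical members.
bad_keys = [jax.random.PRNGKey(seed)] * ensemble_size
bad = run_ensemble(model, inputs, bad_keys, steps)
bad_identical = bool(jnp.all(bad == bad[0]))

print(good.shape, good_differ, bad_identical)
```

A Python loop mirrors the snippet in the comment; with a pure toy model like this, `jax.vmap` over the keys would be an alternative way to batch the members.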
weatherforecasterwhai commented 7 months ago

Thank you very much for your quick reply! Honestly, I didn't know how to use the keys; I will try your approach soon. More detailed documentation would really help me understand how you achieved such good results! I have tried my best to read the code, but in vain.

(Sent by email in reply to Dima Kochkov's comment of 2024-04-08 21:17:44.)