google-research / neuralgcm

Hybrid ML + physics model of the Earth's atmosphere
https://neuralgcm.readthedocs.io
Apache License 2.0

Reproducing the decadal simulations #99

Open ricardobarroslourenco opened 1 month ago

ricardobarroslourenco commented 1 month ago

Hello! Thanks for another great contribution.

While reading the NeuralGCM documentation (https://neuralgcm.readthedocs.io/en/latest/), I wasn't able to find instructions on reproducing the decadal simulations you provided in the Nature paper. Can you refer me to this?

yaniyuval commented 1 month ago

Hi, we ran decadal simulations with the 2.8 degree resolution NeuralGCM model. We ran 37 different initial conditions starting from 01/01/1980 (in intervals of 10 days, so the next initial condition we used was 01/11/1980, and so on). We prescribe SST and sea-ice extent. There is a brief description of the details in the paper (https://www.nature.com/articles/s41586-024-07744-y#Sec19).
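For concreteness, that set of initialization dates is simply 37 dates at 10-day spacing (an illustrative snippet, not our actual tooling):

import pandas as pd

# 37 initialization dates, 10 days apart, starting on 1980-01-01.
init_times = pd.date_range('1980-01-01', periods=37, freq='10D')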

ricardobarroslourenco commented 1 month ago

Hi @yaniyuval, thanks for the prompt answer. I'm just wondering how to replicate the results you reported in the paper.

Would it be like you describe in the quick start guidelines (https://neuralgcm.readthedocs.io/en/latest/inference_demo.html#forecasting-quick-start)? Do you have a script you can share on these decadal simulations?

kochkov92 commented 1 month ago

Hi @ricardobarroslourenco!

The forecasting-quick-start does contain most of the pieces, but it doesn't cover a few sharp bits related to the large memory requirements.

For rollouts we indeed followed the same strategy as in the forecasting-quick-start guide: we iteratively call model.unroll, passing the final_state from the previous iteration as input and updating the all_forcings inputs to provide the correct sea_ice and sea_surface_temperature fields. Doing it in steps and saving chunks to disk (in Zarr format) along the way makes it possible to avoid running out of RAM. Another memory-saving trick we used is to precompute coarsened ERA5 data to avoid pulling full-resolution forcings along the way.
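Schematically, the loop looks something like the sketch below. This is not our actual script: model, inputs, initial_forcings and rng_key are assumed to be set up as in the quick-start guide, and load_coarse_forcings, the output interval, and the chunk sizes are placeholders.

import numpy as np

timedelta = np.timedelta64(6, 'h')   # output interval (placeholder)
steps_per_chunk = 40                 # e.g. ~10 days of 6-hourly outputs per chunk
n_chunks = 4 * 365                   # enough 10-day chunks to cover ~40 years
start_time = np.datetime64('1980-01-01T06:00')

state = model.encode(inputs, initial_forcings, rng_key)

for chunk in range(n_chunks):
    # Pre-coarsened SST / sea-ice forcings covering only this window,
    # so the full-resolution ERA5 data never has to be pulled.
    all_forcings = load_coarse_forcings(chunk)  # placeholder helper

    # Continue the rollout from the final state of the previous chunk.
    state, predictions = model.unroll(
        state, all_forcings, steps=steps_per_chunk, timedelta=timedelta)

    chunk_times = start_time + timedelta * np.arange(
        chunk * steps_per_chunk, (chunk + 1) * steps_per_chunk)
    chunk_ds = model.data_to_xarray(predictions, times=chunk_times)

    # Append each chunk to a Zarr store so only one chunk lives in RAM at a time.
    if chunk == 0:
        chunk_ds.to_zarr('decadal_rollout.zarr', mode='w')
    else:
        chunk_ds.to_zarr('decadal_rollout.zarr', append_dim='time')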

Unfortunately the full script orchestrating this process currently uses internal tools that are not available externally, but we hope to port a version of that together with the updated API that is currently under active development. In the meantime we would be happy to provide more guidance.

ricardobarroslourenco commented 1 month ago

Thanks, @kochkov92, for the additional context.

I understand that there are more complex optimization bits that can't easily be shared, especially when doing research in an industrial setting. However, experimental reproducibility is important, and the community has been discussing this issue for a long time (remembering a funny example here, no offence intended: https://youtu.be/N2zK3sAtr-4?si=T1ifnEf-8wjF-GWE).

Perhaps an initial step would be to relax the constraints so the model can run at a smaller scale, as you mentioned. Do you think it would be possible to run the model in a Dask environment (instead of JAX)? Since you are already using Xarray and Zarr, I suppose the so-called "Pangeo stack" has a chance here. This would enable running on HPC facilities (through Dask-Jobqueue), which most of the community can access, for example.
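Something along these lines, for instance, would let the Xarray/Zarr post-processing run on a typical batch scheduler (SLURM is used purely as an illustration, and the resource numbers are placeholders):

from dask.distributed import Client
from dask_jobqueue import SLURMCluster

# Request workers through the batch scheduler; sizes are placeholders.
cluster = SLURMCluster(cores=8, memory='32GB', walltime='04:00:00')
cluster.scale(jobs=4)
client = Client(cluster)  # subsequent Xarray/Zarr operations can now use the cluster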

kochkov92 commented 1 month ago

Thanks for sharing the example - it is funny!

Dask indeed might be a good option! I think as a first step we will try to get model forecasts out that can be selectively reproduced using the forecasting-quick-start guide.

I do really hope that we can split up our inference code and make it publicly available (as well as the updated model API that I'm personally quite excited about).

ricardobarroslourenco commented 1 month ago

Nice. Let me know if you need help porting it to Dask. I have been using it recently in my dissertation work and would be happy to help.

ShihengDuan commented 2 weeks ago

> Hi, we ran decadal simulations with the 2.8 degree resolution NeuralGCM model. We ran 37 different initial conditions starting from 01/01/1980 (in intervals of 10 days, so the next initial condition we used was 01/11/1980, and so on). We prescribe SST and sea-ice extent. There is a brief description of the details in the paper (https://www.nature.com/articles/s41586-024-07744-y#Sec19).

Hi @yaniyuval, are the SST and sea ice for the decadal simulations taken from ERA5 or CMIP6?

ShihengDuan commented 2 weeks ago

I'm getting NaN values when initializing with ERA5 starting from 1980-01-01 (with a 1-day interval between initial conditions). Most runs show NaN values after several years (e.g., around 1982-1984), with only a few staying stable up to 2018. Is this normal? I'm using time-varying SST and sea ice.

yaniyuval commented 2 weeks ago

@ShihengDuan, the SST and sea ice are taken from ERA5. We ran 37 initial conditions for 40 years and found that 22 of them were stable. If you are getting a different result, my best guess is that the SST interpolation you use might not be consistent with ours. @kochkov92, can you comment on the interpolation we used?

ShihengDuan commented 2 weeks ago

I'm using the code from the tutorial for interpolation:

# Regridding utilities (from the dinosaur package used by NeuralGCM):
from dinosaur import horizontal_interpolation
from dinosaur import spherical_harmonic
from dinosaur import xarray_utils

# Grid describing the lat/lon layout of the full-resolution ERA5 data.
era5_grid = spherical_harmonic.Grid(
    latitude_nodes=full_era5.sizes['latitude'],
    longitude_nodes=full_era5.sizes['longitude'],
    latitude_spacing=xarray_utils.infer_latitude_spacing(full_era5.latitude),
    longitude_offset=xarray_utils.infer_longitude_offset(full_era5.longitude),
)
# Conservative regridding from the ERA5 grid to the model's horizontal grid.
regridder = horizontal_interpolation.ConservativeRegridder(
    era5_grid, model.data_coords.horizontal, skipna=True
)
# Regrid to the model grid, then fill remaining NaNs with nearest-neighbour values.
eval_era5 = xarray_utils.regrid(sliced_era5, regridder)
eval_era5 = xarray_utils.fill_nan_with_nearest(eval_era5)

The SST and sea ice are indeed from another dataset (PCMDI). I tried initializing from Jan 01 to Jan 30 1980, and only 4 runs show stable results up to 2018 (the others show NaN values).

kochkov92 commented 1 week ago

@ShihengDuan @yaniyuval - I believe this might be due to a bug in the order in which fill_nan_with_nearest was applied.

Here @ShihengDuan calls regrid (I believe it's called regrid_horizontal now) first, which updates the coordinates to be in deg, and then calls fill_nan_with_nearest. This is the correct order of operations.

The checkpoints that we currently have up on the cloud were trained with data processed as follows:

  1. Applying horizontal regrid without updating the lon/lat coordinates to deg (i.e. keeping them in rad units)
  2. Calling fill_nan_with_nearest
  3. Updating deg -> rad

This process is faulty and results in a few locations that have bogus filled values (this affects what the NN has learned at those locations).

For weather timescales the results should be extremely similar, as it only affects a few boundary points where our ML components switch from ignoring SST values to using other inputs, but there is a discrepancy which could affect stability. There was a comment on this in the original inference notebook that apparently got removed as we were cleaning things up.

I would suggest trying the "faulty" regridding to check whether it changes stability. Over a slightly longer time horizon, we will retrain and update our models to use the correctly filled values.

yaniyuval commented 1 week ago

@ShihengDuan, I would add that, in private communication, I have discussed this with people who ran NeuralGCM decadal simulations (to my understanding, without correcting this interpolation issue), and they report stability similar to what we reported in the paper (but slightly different results due to the different interpolation). So potentially (though I am not sure) there is some other issue in how you run the simulations. How frequently are you updating SST/sea ice?

ShihengDuan commented 1 week ago

@yaniyuval
I'm updating the SST/sea ice at a daily interval (interpolated to daily values and applied as input forcings).
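Roughly like this (with sst_and_sea_ice as a placeholder name for the PCMDI dataset):

import pandas as pd

# Linearly interpolate the forcing fields onto daily timestamps.
daily_times = pd.date_range('1980-01-01', '2018-12-31', freq='1D')
forcings_daily = sst_and_sea_ice.interp(time=daily_times)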

I am also trying to increase the initialization interval to 1 week (I noticed in the paper it is 10 days). I'm guessing the 1-day interval is too short, so the sampling is not representative.

@kochkov92 I'll try to change the order of regridding. Could you provide a few lines of code for this "faulty" order?