[Open] jswhit opened this issue 6 months ago
Are you interested in a particular version of NeuralGCM?
The internal state of the model (outputs of PressureLevelModel.encode, inputs/outputs of advance, and inputs/first output of unroll) is actually already on sigma coordinates, e.g.:
>>> jax.tree_util.tree_map(np.shape, final_state)
ModelState(state=StateWithTime(vorticity=(32, 256, 129), divergence=(32, 256, 129), temperature_variation=(32, 256, 129), log_surface_pressure=(1, 256, 129), sim_time=(), tracers={'specific_cloud_ice_water_content': (32, 256, 129), 'specific_cloud_liquid_water_content': (32, 256, 129), 'specific_humidity': (32, 256, 129)}), memory=StateWithTime(vorticity=(32, 256, 129), divergence=(32, 256, 129), temperature_variation=(32, 256, 129), log_surface_pressure=(1, 256, 129), sim_time=(), tracers={'specific_cloud_ice_water_content': (32, 256, 129), 'specific_cloud_liquid_water_content': (32, 256, 129), 'specific_humidity': (32, 256, 129)}), diagnostics={}, randomness=RandomnessState(core=(256, 128), nodal_value=(256, 128), modal_value=(256, 129)))
These variables are stored as spherical harmonic coefficients, so they need to be transformed back into grid space for visualization. For example, to visualize temperature near the surface in the demo notebook:
temp_variation = neural_gcm_model.model_coords.horizontal.to_nodal(final_state.state.temperature_variation)
xarray.DataArray(temp_variation[-1, :, :], dims=['x', 'y']).plot.imshow(size=4, aspect=2, x='x', y='y')
There are a few tricks for the conversion (e.g., handling units and reference offsets for temperature). We'll work on documenting these as part of https://github.com/google-research/neuralgcm/issues/11.
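To make the "reference offset" trick concrete: the model carries temperature as a variation about a fixed per-level reference profile, so the physical field is recovered by adding that profile back. The sketch below is a toy illustration in plain NumPy, not the NeuralGCM API; the profile values and array sizes are made up.

```python
import numpy as np

# Hypothetical per-level reference temperature profile, in kelvin.
ref_temperatures = np.linspace(220.0, 288.0, 4)

# Toy temperature_variation field with shape (level, lon, lat).
temperature_variation = np.full((4, 8, 16), 1.5)

# Physical temperature = reference profile + variation, broadcast over levels.
temperature = ref_temperatures[:, None, None] + temperature_variation
```

A real conversion would also need the model's unit handling (e.g., from_nondim_units as used later in this thread).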
@shoyer we're most interested in the stochastic version, since we'll be running ensembles for data assimilation. What we need to do is:
1. Run the model ensemble for 9 hours, saving the trajectory state every 3 hours (or perhaps every hour).
2. Extract the model state on sigma levels from the trajectory and use it to compute the "observation equivalents" for observations valid between 3 and 9 hours.
3. Use the model states (in grid space and also in observation space) to compute an update to the model state at 6 hours with the Ensemble Kalman Filter algorithm.
4. Use these updated model states to re-initialize the model, then rinse and repeat.
Your example looks like exactly what we need for steps 1-3. For step 4, we need a way to re-encode the updated sigma-level states.
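The analysis update in step 3 can be sketched as a toy perturbed-observations EnKF in plain NumPy. All sizes, the linear observation operator H, and the error covariance R below are hypothetical stand-ins for the sigma-level model state and the observation equivalents; this is the textbook algorithm, not NeuralGCM code.

```python
import numpy as np

rng = np.random.default_rng(0)

n_state, n_obs, n_ens = 8, 3, 20
X = rng.normal(size=(n_state, n_ens))   # forecast ensemble (one member per column)
H = rng.normal(size=(n_obs, n_state))   # hypothetical linear observation operator
R = 0.1 * np.eye(n_obs)                 # observation-error covariance
y = rng.normal(size=n_obs)              # observations

# Ensemble anomalies in state and observation space.
A = X - X.mean(axis=1, keepdims=True)
HX = H @ X
HA = HX - HX.mean(axis=1, keepdims=True)

# Sample covariances P H^T and H P H^T, then the Kalman gain.
PHt = A @ HA.T / (n_ens - 1)
HPHt = HA @ HA.T / (n_ens - 1)
K = PHt @ np.linalg.inv(HPHt + R)

# Perturbed observations: each member is updated toward its own noisy copy of y.
Y = y[:, None] + rng.multivariate_normal(np.zeros(n_obs), R, size=n_ens).T
Xa = X + K @ (Y - HX)                   # analysis ensemble for re-initialization
```

In the cycling described above, X would hold the sigma-level state at 6 hours and HX the observation equivalents computed in step 2.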
@shoyer we've got an initial version of a cycling EnKF for neuralgcm running and assimilating surface pressure observations. The results suggest that the orography we are using for the forward observation operator (from aux_features['xarray_dataset']['geopotential_at_surface']) is perhaps not the actual orography the dycore uses (and is therefore not consistent with the surface pressure in the model state). Does the dycore use a filtered version of that orography, and if so, how can I access it?
@jswhit Indeed, for the dycore we use a learned filtered version of orography that is a bit smoother. You can extract it from the trained models:
import dataclasses
import functools
import pickle
import gcsfs
from dinosaur import spherical_harmonic
import haiku as hk
from neuralgcm import api
from neuralgcm import orographies
import numpy as np
import matplotlib.pyplot as plt
import xarray
gcs = gcsfs.GCSFileSystem(token='anon')
# Load the trained checkpoint first; get_orography below closes over
# neural_gcm_model.
model_name = 'neural_gcm_stochastic_1_4_deg_v0.pkl'
with gcs.open(f'gs://gresearch/neuralgcm/03_04_2024/{model_name}', 'rb') as f:
  ckpt = pickle.load(f)
neural_gcm_model = api.PressureLevelModel.from_checkpoint(ckpt)

@hk.transform
def get_orography():
  base_orography = functools.partial(
      orographies.FilteredCustomOrography,
      orography_data_path=None,
      renaming_dict=dict(longitude='lon', latitude='lat'),
  )
  orography_coords = dataclasses.replace(
      neural_gcm_model.model_coords,
      horizontal=spherical_harmonic.Grid.with_wavenumbers(longitude_wavenumbers=126),
  )
  return orographies.LearnedOrography(
      orography_coords,
      neural_gcm_model._structure.specs.dt,
      neural_gcm_model._structure.specs.physics_specs,
      neural_gcm_model._structure.specs.aux_features,
      base_orography_module=base_orography,
      correction_scale=1e-5,
  )()
dycore_coords = spherical_harmonic.Grid.with_wavenumbers(longitude_wavenumbers=126)
learned_correction = neural_gcm_model.params[
'stochastic_modular_step_model/~/stochastic_physics_parameterization_step/'
'~/custom_coords_corrector/~/dycore_with_physics_corrector/~/learned_orography'
]['orography']
learned_orography_modal = get_orography.apply(
{'learned_orography': {'orography': learned_correction}}, rng=None
)
learned_orography = neural_gcm_model.from_nondim_units(
dycore_coords.to_nodal(learned_orography_modal), units='meters'
)
default_orography_modal = get_orography.apply(
{'learned_orography': {'orography': np.zeros_like(learned_correction)}}, rng=None
)
default_orography = neural_gcm_model.from_nondim_units(
dycore_coords.to_nodal(default_orography_modal), units='meters'
)
# learned orography
xarray.DataArray(learned_orography, dims=['x', 'y']).plot.imshow(x='x', y='y', aspect=2, size=4, robust=True, vmin=0)
# default orography
xarray.DataArray(default_orography, dims=['x', 'y']).plot.imshow(x='x', y='y', aspect=2, size=4, robust=True, vmin=0)
Thanks @shoyer, using the learned orography really moved the needle for the data assimilation.
For future reference, details of how to work with our model on sigma coordinates are now described in the NeuralGCM documentation: https://neuralgcm.readthedocs.io/en/latest/trained_models.html
@shoyer for the deterministic 0.7 degree model it looks like the learned orography is in stochastic_modular_step_model/~/dimensional_learned_weatherbench_to_primitive_with_memory_encoder/~/learned_weatherbench_to_primitive_encoder_1/~/learned_orography, and correction_scale should be 2e-6. Is this correct?
@jswhit good catch! Yes, we did switch to using slightly different values for that resolution.
Probably worth noting that the model contains up to 3 different orography values (one in the encoder, one for the simulation, and one in the decoder).
I believe the model that @shoyer showed has fixed orography (a modal representation of conservatively regridded ERA5 orography) used by the encoder and decoder, and learned orography in the advance step (extracted by @shoyer's snippet).
In the 0.7 degree model all three components were learned. In a few experiments we found that in this setting the decoder tends to keep sharper features in the orography than the encoder. The variables that you've tracked down are the encoder orography. Please let me know if the other one under "dycore_with_physics_corrector" doesn't show up.
We will also prioritize easier access to these in the new API that we are starting to work on.
Thanks @kochkov92, I was able to find 'stochastic_modular_step_model/~/stochastic_physics_parameterization_step/~/custom_coords_corrector/~/dycore_with_physics_corrector/~/learned_orography'. However, when I use it with the get_orography function @shoyer provided:
@hk.transform
def get_orography():
  base_orography = functools.partial(
      orographies.FilteredCustomOrography,
      orography_data_path=None,
      renaming_dict=dict(longitude='lon', latitude='lat'),
  )
  orography_coords = dataclasses.replace(
      neural_gcm_model.model_coords,
      horizontal=spherical_harmonic.Grid.with_wavenumbers(longitude_wavenumbers=256),
  )
  return orographies.LearnedOrography(
      orography_coords,
      neural_gcm_model._structure.specs.dt,
      neural_gcm_model._structure.specs.physics_specs,
      neural_gcm_model._structure.specs.aux_features,
      base_orography_module=base_orography,
      correction_scale=2e-6,
  )()

orogpath = (
    'stochastic_modular_step_model/~/stochastic_physics_parameterization_step/'
    '~/custom_coords_corrector/~/dycore_with_physics_corrector/~/learned_orography'
)
learned_correction = neural_gcm_model.params[orogpath]['orography']
learned_orography_modal = get_orography.apply(
    {'learned_orography': {'orography': learned_correction}}, rng=None
)
I get
ValueError: 'learned_orography/orography' with retrieved shape (65023,) does not match shape=(66047,) dtype=<class 'jax.numpy.float32'>
Please disregard my previous question: it works if I use longitude_wavenumbers=254.
In order to use neuralgcm for data assimilation, we need access to the model variables on sigma levels (plus surface pressure). The api.PressureLevelModel provides access to decoded data on pressure levels. I'd like to add the ability to extract sigma-level data and surface pressure from a prediction, then restart a model prediction using updated sigma-level/surface-pressure fields. How would you recommend doing this with the public API? @nniraj123 @frolovsa