google-research / neuralgcm

Hybrid ML + physics model of the Earth's atmosphere
https://neuralgcm.readthedocs.io
Apache License 2.0

How to convert sigma level vertical velocity to pressure level? #69

Open weatherforecasterwhai opened 6 months ago

weatherforecasterwhai commented 6 months ago

Thanks to @shoyer in #68, I now have the sigma-level vertical velocity. I tried to use _model.model_coords.horizontal.to_nodal(vertical_velocity)_ on it, as @shoyer suggested in https://github.com/google-research/neuralgcm/issues/8.

However, it fails with this error: size of label '1' for operand 1 (65) does not match previous terms (64). I've searched but can't fix it myself. As in #8, running _model.model_coords.horizontal.to_nodal(advanced.state.temperature_variation)_ works fine. The full traceback is:

(Traceback screenshot not reproduced.)

shoyer commented 6 months ago

I believe to_nodal and to_modal only operate on 2D arrays. If you want to transform a higher-dimensional field, wrap the transformations in jax.vmap.
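A minimal sketch of the jax.vmap pattern described above. `transform_2d` is a hypothetical stand-in for any function that only accepts a single 2D (longitude, latitude) slice; it is not the neuralgcm `to_nodal` method. `jax.vmap` maps it over the leading (level) axis so it accepts a 3D array.

```python
import jax
import jax.numpy as jnp


def transform_2d(field_2d: jax.Array) -> jax.Array:
    # Hypothetical stand-in for a transform that only accepts one 2D slice.
    if field_2d.ndim != 2:
        raise ValueError(f"expected a 2D array, got shape {field_2d.shape}")
    return 2.0 * field_2d


# jax.vmap maps transform_2d over axis 0 (the level axis), so it accepts
# a 3D (level, longitude, latitude) array and returns a 3D result.
transform_3d = jax.vmap(transform_2d, in_axes=0)

field = jnp.ones((32, 128, 64))
result = transform_3d(field)
print(result.shape)  # (32, 128, 64)
```

Inside the vmapped call each `field_2d` has shape (128, 64), so the 2D-only check passes for every level.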

weatherforecasterwhai commented 6 months ago

Thanks for pointing out the dimensions.
1. advanced.state.temperature_variation.shape is (32,128,65), while the vertical velocity shape is (32,128,64). So maybe to_nodal can operate on a 3D array of shape (32,128,65) but not on one of shape (32,128,64), and the error comes from the mismatch between 65 and 64. More interestingly, model.model_coords.nodal_shape is (32,128,64)! So does to_nodal need one extra grid point to convert from modal to nodal? Running z=jnp.zeros((32,128,1)) and omega1=jnp.concatenate((omega,z),axis=2) gives a (32,128,65) array, and that seems to work.

2. Running xarray_utils.save_netcdf(xarray.DataArray(omega).to_dataset(name="omega"),"omega.nc") saves a netCDF file, but GrADS (a popular, easy-to-use tool) can't read it easily. Using neural_gcm_model.data_to_xarray(predictions,times=times).to_netcdf('*.nc') works much better, because its format is well defined.

Is there a better way to 1. solve the dimension problem, and 2. save to netCDF the way neural_gcm_model.data_to_xarray does? Thank you!

shoyer commented 6 months ago

> 1. advanced.state.temperature_variation.shape is (32,128,65), while the vertical velocity shape is (32,128,64). So maybe to_nodal can operate on a 3D array of shape (32,128,65) but not on one of shape (32,128,64), and the error comes from the mismatch between 65 and 64.

Please double-check, but I believe the sigma-level vertical velocity is already defined in "nodal" space (on the Gaussian grid), and hence does not need the to_nodal() transform.
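One way to double-check is to compare shapes. The sketch below is a hypothetical helper, not a neuralgcm function; it uses only the two shapes reported in this thread: `model.model_coords.nodal_shape` = (32, 128, 64) and the modal `advanced.state.temperature_variation.shape` = (32, 128, 65).

```python
def classify_field(shape, nodal_shape, modal_shape):
    """Hypothetical helper: decide from shapes alone whether a field is nodal or modal."""
    shape = tuple(shape)
    nodal_shape = tuple(nodal_shape)
    modal_shape = tuple(modal_shape)
    if nodal_shape == modal_shape:
        # Shapes alone cannot tell the two representations apart in this case.
        raise ValueError("nodal and modal shapes are identical; shape check is ambiguous")
    if shape == nodal_shape:
        return "nodal"  # already on the Gaussian grid: do not call to_nodal()
    if shape == modal_shape:
        return "modal"  # spherical-harmonic coefficients: to_nodal() applies
    raise ValueError(
        f"shape {shape} matches neither nodal {nodal_shape} nor modal {modal_shape}"
    )


# Shapes reported in this thread.
NODAL_SHAPE = (32, 128, 64)  # model.model_coords.nodal_shape
MODAL_SHAPE = (32, 128, 65)  # advanced.state.temperature_variation.shape

print(classify_field((32, 128, 64), NODAL_SHAPE, MODAL_SHAPE))  # nodal
print(classify_field((32, 128, 65), NODAL_SHAPE, MODAL_SHAPE))  # modal
```

With these shapes the (32, 128, 64) vertical velocity classifies as nodal. Padding it with zeros to (32, 128, 65) only gives it the modal shape; its values are still grid-point values, so to_nodal would misread them as spectral coefficients.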

> 2. Running xarray_utils.save_netcdf(xarray.DataArray(omega).to_dataset(name="omega"),"omega.nc") saves a netCDF file, but GrADS (a popular, easy-to-use tool) can't read it easily. Using neural_gcm_model.data_to_xarray(predictions,times=times).to_netcdf('*.nc') works much better, because its format is well defined.

You should use data_to_xarray(...).to_netcdf(). There is a lot of old code in xarray_utils that I would not recommend using, and we may remove it to avoid confusion.
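data_to_xarray is the recommended route. If you only have a raw array such as omega, a plain xarray.DataArray(omega) gets anonymous dims (dim_0, dim_1, dim_2) and no coordinates. The sketch below attaches named dimensions and CF-style units instead. It assumes a (level, longitude, latitude) layout matching nodal_shape in this thread. `to_labeled_dataset` and its arguments are hypothetical, and the coordinate values must come from your model grid. CF units such as degrees_east/degrees_north are what many netCDF tools, GrADS included, typically use to recognize axes, but check this against your GrADS version.

```python
import numpy as np
import xarray as xr


def to_labeled_dataset(values, levels, longitudes, latitudes, name="omega", level_units="1"):
    """Wrap a raw (level, longitude, latitude) array in an xarray.Dataset.

    Named dimensions and CF-style units attributes replace the anonymous
    dim_0/dim_1/dim_2 that xarray.DataArray(values) would produce.
    Use level_units="1" for dimensionless sigma, or e.g. "hPa" for pressure levels.
    """
    da = xr.DataArray(
        np.asarray(values),
        dims=("level", "longitude", "latitude"),
        coords={
            # (dims, data, attrs) tuples attach units to each coordinate.
            "level": ("level", np.asarray(levels), {"units": level_units}),
            "longitude": ("longitude", np.asarray(longitudes), {"units": "degrees_east"}),
            "latitude": ("latitude", np.asarray(latitudes), {"units": "degrees_north"}),
        },
        name=name,
    )
    # Reorder to the common (level, latitude, longitude) layout.
    return da.transpose("level", "latitude", "longitude").to_dataset()
```

The result can then be written with `ds.to_netcdf("omega.nc")`.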

weatherforecasterwhai commented 6 months ago

@shoyer I can't tell from the code whether it outputs sigma-level or pressure-level values. Please have a look at compute_vertical_velocity:

```python
def compute_vertical_velocity(
    state: State, coords: coordinate_systems.CoordinateSystem
) -> jax.Array:
    """Calculate vertical velocity at the center of each layer."""
    sigma_dot_boundaries = compute_diagnostic_state(state, coords).sigma_dot_full
    assert sigma_dot_boundaries.ndim == 3
    # This matches the default boundary conditions for vertical velocity
    # from sigma_coordinates.centered_vertical_advection
    sigma_dot_padded = jnp.pad(sigma_dot_boundaries, [(1, 1), (0, 0), (0, 0)])
    return 0.5 * (sigma_dot_padded[1:] + sigma_dot_padded[:-1])
```
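To see what the last two lines of compute_vertical_velocity do to the level axis, here is a NumPy re-creation on a toy column. It assumes sigma_dot_full holds values only at the n_layers - 1 interior layer boundaries, with zero at the model top and surface supplied by the padding. The result is still on sigma-layer centers; nothing in this step involves pressure levels.

```python
import numpy as np


def boundaries_to_centers(sigma_dot_boundaries: np.ndarray) -> np.ndarray:
    """Average sigma-dot from layer boundaries onto layer centers.

    Assumes the input holds only the n_layers - 1 interior boundaries,
    with shape (n_layers - 1, lon, lat).
    """
    assert sigma_dot_boundaries.ndim == 3
    # Zero sigma-dot at the top and bottom boundaries (no flow through
    # the model top or the surface), padded onto the level axis only.
    padded = np.pad(sigma_dot_boundaries, [(1, 1), (0, 0), (0, 0)])
    # Mean of each adjacent pair of boundaries -> one value per layer.
    return 0.5 * (padded[1:] + padded[:-1])


# One column with 3 layers, hence 2 interior boundaries.
column = np.array([2.0, 4.0]).reshape(2, 1, 1)
print(boundaries_to_centers(column)[:, 0, 0])  # [1. 3. 2.]
```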

shoyer commented 6 months ago

Vertical velocity (and everything inside the dynamical core) is output on sigma levels. If you want it on pressure levels, you can use the vertical interpolation routines from dinosaur.vertical_interpolation.
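The thread does not name the specific functions in dinosaur.vertical_interpolation, so the following is a self-contained NumPy sketch of the underlying idea, not that module's API. It assumes pure sigma coordinates (p = sigma × surface pressure) and linear interpolation in pressure, column by column. Note that it only moves values between vertical grids. Interpolating sigma-dot (a rate of change of sigma) yields sigma-dot on pressure levels, not pressure velocity omega.

```python
import numpy as np


def sigma_to_pressure_levels(field, sigma_centers, surface_pressure, pressure_levels):
    """Linearly interpolate a (level, lon, lat) field from sigma levels to pressure levels.

    Assumes pure sigma coordinates, p = sigma * p_surface, with sigma_centers
    increasing from the model top toward the surface. Illustrative only; this is
    not the dinosaur.vertical_interpolation API.
    """
    field = np.asarray(field, dtype=float)
    sigma_centers = np.asarray(sigma_centers, dtype=float)
    surface_pressure = np.asarray(surface_pressure, dtype=float)
    pressure_levels = np.asarray(pressure_levels, dtype=float)
    _, n_lon, n_lat = field.shape
    out = np.empty((pressure_levels.size, n_lon, n_lat))
    for i in range(n_lon):
        for j in range(n_lat):
            # Pressure at each sigma level in this column (increasing downward).
            column_pressure = sigma_centers * surface_pressure[i, j]
            # np.interp clamps targets outside the column's range to the end values.
            out[:, i, j] = np.interp(pressure_levels, column_pressure, field[:, i, j])
    return out
```

Surface pressure and target levels just need consistent units (e.g. both in hPa). Targets below the lowest sigma level are clamped rather than extrapolated, which a production routine would handle more carefully.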

weatherforecasterwhai commented 6 months ago

Vertical velocity is really important for weather analysis. Please provide an easy API to output it in the same way as p_e or other variables. Thank you!

shoyer commented 6 months ago

Another option is to diagnose vertical velocity from an output dataset on pressure levels using the WeatherBench2 code: https://weatherbench2.readthedocs.io/en/latest/_autosummary/weatherbench2.derived_variables.VerticalVelocity.html#weatherbench2.derived_variables.VerticalVelocity
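For intuition about what diagnosing vertical velocity from pressure-level output involves, here is a NumPy sketch of the classic kinematic method. It integrates the pressure-coordinate continuity equation, d(omega)/dp = -div(V), downward from the top level. This is not necessarily the scheme WeatherBench2's VerticalVelocity uses. The zero-omega top boundary condition and the level ordering are assumptions of this sketch. Negative omega means ascending air.

```python
import numpy as np


def omega_from_divergence(divergence, pressure):
    """Diagnose pressure velocity omega from horizontal divergence on pressure levels.

    Integrates d(omega)/dp = -div(V) downward from the first (top) level with
    the trapezoidal rule. Assumes omega = 0 at the top level and `pressure` (Pa)
    increasing along axis 0 of `divergence` (1/s).
    """
    divergence = np.asarray(divergence, dtype=float)
    pressure = np.asarray(pressure, dtype=float)
    # Layer thickness dp, reshaped to broadcast against (level - 1, ...) arrays.
    dp = np.diff(pressure).reshape((-1,) + (1,) * (divergence.ndim - 1))
    # Mean divergence in each layer between adjacent levels.
    layer_mean = 0.5 * (divergence[1:] + divergence[:-1])
    # Cumulative integral from the top; omega at the top level is zero.
    omega_below_top = -np.cumsum(layer_mean * dp, axis=0)
    return np.concatenate([np.zeros_like(divergence[:1]), omega_below_top], axis=0)
```

With a constant divergence D, this gives omega = -D (p - p_top) exactly, since the trapezoidal rule is exact for a constant integrand.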