google-deepmind / graphcast

Apache License 2.0

Are forcing variables repeated? #52

Closed Javier-Jimenez99 closed 5 months ago

Javier-Jimenez99 commented 5 months ago

Tasks are defined in graphcast.py. For example, the task used for 1 degree resolution and 13 pressure levels has these variables:

    input_variables=(
        '2m_temperature',
        'mean_sea_level_pressure',
        '10m_v_component_of_wind',
        '10m_u_component_of_wind',
        'total_precipitation_6hr',
        'temperature',
        'geopotential',
        'u_component_of_wind',
        'v_component_of_wind',
        'vertical_velocity',
        'specific_humidity',
        'toa_incident_solar_radiation',
        'year_progress_sin',
        'year_progress_cos',
        'day_progress_sin',
        'day_progress_cos',
        'geopotential_at_surface',
        'land_sea_mask',
    )
    forcing_variables=(
        'toa_incident_solar_radiation',
        'year_progress_sin',
        'year_progress_cos',
        'day_progress_sin',
        'day_progress_cos',
    )

The forcing variables are included among the input ones. However, inside graphcast, the function that transforms xarrays to numpy is _inputs_to_grid_node_features:

def _inputs_to_grid_node_features(
        self,
        inputs: xarray.Dataset,
        forcings: xarray.Dataset,
    ) -> chex.Array:
        """xarrays -> [num_grid_nodes, batch, num_channels]."""

        # xarray `Dataset` (batch, time, lat, lon, level, multiple vars)
        # to xarray `DataArray` (batch, lat, lon, channels)
        stacked_inputs = model_utils.dataset_to_stacked(inputs)
        stacked_forcings = model_utils.dataset_to_stacked(forcings)
        stacked_inputs = xarray.concat(
            [stacked_inputs, stacked_forcings], dim="channels"
        )

        # xarray `DataArray` (batch, lat, lon, channels)
        # to single numpy array with shape [lat_lon_node, batch, channels]
        grid_xarray_lat_lon_leading = model_utils.lat_lon_to_leading_axes(
            stacked_inputs
        )
        return xarray_jax.unwrap(grid_xarray_lat_lon_leading.data).reshape(
            (-1,) + grid_xarray_lat_lon_leading.data.shape[2:]
        )

The question is: why are the input and forcing variables concatenated if the forcings are already included in the inputs?

alvarosg commented 5 months ago

The difference between forcings and inputs mostly has to do with the time associated with the features. inputs contain data for past and present times. forcings and targets contain data for future times.

Having a variable both in inputs and in forcings means you are going to condition your model both on present/past values of that variable, as well as on future values of that variable. Conditioning on a future value is possible because every forcing variable is something that can be computed analytically for any future date.
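To illustrate why such variables can be known for any future date, here is a minimal sketch of computing the progress features from a timestamp alone. The function name `progress_features`, the year-length constant, and the omission of any per-longitude offset for the day progress are assumptions made for illustration; graphcast's own implementation may differ in these details.

```python
import math
from datetime import datetime, timezone

SECONDS_PER_DAY = 24 * 60 * 60
# Assumption: mean tropical year length; the library's own constant may differ.
SECONDS_PER_YEAR = 365.24219 * SECONDS_PER_DAY


def progress_features(when: datetime) -> dict:
    """Hypothetical helper: sin/cos encodings of year and day progress.

    Depends only on the timestamp, so it works for any past or future date.
    Ignores longitude, i.e. the day progress is relative to UTC.
    """
    seconds = when.timestamp()
    year_phase = 2 * math.pi * ((seconds / SECONDS_PER_YEAR) % 1.0)
    day_phase = 2 * math.pi * ((seconds % SECONDS_PER_DAY) / SECONDS_PER_DAY)
    return {
        "year_progress_sin": math.sin(year_phase),
        "year_progress_cos": math.cos(year_phase),
        "day_progress_sin": math.sin(day_phase),
        "day_progress_cos": math.cos(day_phase),
    }


# A date years ahead is no harder to compute than one in the past.
future = progress_features(datetime(2030, 1, 1, 6, tzinfo=timezone.utc))
```

No observation or model output is needed here, which is what separates a forcing from a target.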

Additionally, to be able to roll out the model autoregressively, every single variable in inputs needs to be either static, a target, or a forcing. As you advance the prediction by one step, the future timesteps become the present timesteps, so anything that is not predicted for the future needs to be available as a forcing. If you look at rollout.py, you will see how, at each step, as time shifts, some forcing variables are moved into the input variables.
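The window shift described above can be sketched with plain numpy. This is a toy illustration, not the code in rollout.py: the function `rollout`, the `predict_fn` callback, the dict-of-arrays layout, and the demo values are all hypothetical.

```python
import numpy as np


def rollout(inputs, forcings, static, predict_fn, num_steps):
    """Toy autoregressive rollout (hypothetical, for illustration only).

    inputs:   dict name -> array [2, ...] holding the steps (t-1, t).
    forcings: dict name -> array [num_steps, ...] of future forcing values.
    static:   dict name -> array [...] with no time axis; never shifted.
    predict_fn(inputs, step_forcings, static) -> dict of targets at t+1.
    """
    predictions = []
    for step in range(num_steps):
        step_forcings = {name: values[step] for name, values in forcings.items()}
        targets = predict_fn(inputs, step_forcings, static)
        predictions.append(targets)

        # The frame for t+1 is built from predicted targets and known forcings.
        next_frame = {**targets, **step_forcings}
        # Shift the window: drop t-1, append t+1. A time-varying input that is
        # neither a target nor a forcing raises KeyError here.
        inputs = {
            name: np.concatenate([window[1:], next_frame[name][None]], axis=0)
            for name, window in inputs.items()
        }
    return predictions


# Demo with a made-up model that adds 1 to the latest temperature.
def toy_predict(inputs, step_forcings, static):
    return {"temperature": inputs["temperature"][-1] + 1.0}


demo_inputs = {
    "temperature": np.array([[0.0], [1.0]]),
    "day_progress_sin": np.array([[0.1], [0.2]]),
}
demo_forcings = {"day_progress_sin": np.array([[0.3], [0.4], [0.5]])}
demo_predictions = rollout(demo_inputs, demo_forcings, {}, toy_predict, num_steps=3)
```

In this sketch the forcing value used at one step becomes part of the inputs for the next step, which is the move from forcings to inputs mentioned above.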

Javier-Jimenez99 commented 5 months ago

Ah okay, now I understand. The variables in inputs are from the previous state (t-1) and the present (t), and the forcings are from the future (t+1), so they are not repeated.

I have another question. Using the same model, after concatenating inputs and forcings I get a numpy array of size num_points x batch x 183, but in params the weight of the embedding layer has size 186 x 512. What are these 3 extra features, and where are they added?

Thank you so much!
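As a sanity check on the 183 figure, the channel count can be reproduced from the variable lists in the first post. The grouping into surface, atmospheric, forcing, and static variables is an assumption made for this sketch, as is the use of two input time steps (t-1 and t) and one forcing step (t+1), with static variables carrying no time axis.

```python
NUM_INPUT_STEPS = 2  # assumption: inputs cover t-1 and t
NUM_LEVELS = 13      # pressure levels of the task in question

# Assumed grouping of the variables listed in the first post.
surface = [
    "2m_temperature", "mean_sea_level_pressure", "10m_v_component_of_wind",
    "10m_u_component_of_wind", "total_precipitation_6hr",
]
atmospheric = [
    "temperature", "geopotential", "u_component_of_wind",
    "v_component_of_wind", "vertical_velocity", "specific_humidity",
]
forcing = [
    "toa_incident_solar_radiation", "year_progress_sin", "year_progress_cos",
    "day_progress_sin", "day_progress_cos",
]
static = ["geopotential_at_surface", "land_sea_mask"]

input_channels = (
    len(surface) * NUM_INPUT_STEPS                     # 10
    + len(atmospheric) * NUM_LEVELS * NUM_INPUT_STEPS  # 156
    + len(forcing) * NUM_INPUT_STEPS                   # 10
    + len(static)                                      # 2, no time axis
)
forcing_channels = len(forcing)  # 5, one future step (t+1)
total_channels = input_channels + forcing_channels
print(total_channels)  # 183
```

Under these assumptions the stacked inputs and forcings account for exactly 183 channels, so the 3 extra rows of the 186 x 512 weight would have to come from something other than these variables.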