NASA-IMPACT / Prithvi-WxC

Implementation of the Prithvi WxC Foundation Model and Downstream Tasks
MIT License

Missing Index:Variable Mapping in both In/Out #28

Open gajeshladhar opened 4 days ago

gajeshladhar commented 4 days ago

Hi team,

While working with the PrithviWxC model from this repository, I noticed that the output tensor has shape (N, 160, H, W), and the 160 channels come with no metadata identifying which variable each one holds.

For example, to extract temperature, I need to know its index (e.g., temp = out[0, 5, :, :] if temperature is at index 5). However, without a clear index-to-variable mapping, it’s difficult to correctly interpret the model’s outputs.

Could you please provide the list of 160 variables along with their corresponding index positions?

Best Regards, Gajesh
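Assuming the channel order encoded in the to_xarray helper posted later in this thread (20 surface variables, then 10 vertical variables x 14 model levels in variable-major order), a name-to-index lookup can be sketched in plain Python. channel_index and channel_names are hypothetical helpers, and the ordering is read from that helper's code rather than confirmed against the model configuration.

```python
# Channel layout ASSUMED from the to_xarray helper in this thread:
# 20 surface variables, then 10 vertical variables x 14 model levels,
# variable-major (all 14 levels of CLOUD, then all 14 of H, and so on).
SFC_VARS = [
    "EFLUX", "GWETROOT", "HFLUX", "LAI", "LWGAB", "LWGEM",
    "LWTUP", "PS", "QV2M", "SLP", "SWGNT", "SWTNT",
    "T2M", "TQI", "TQL", "TQV", "TS", "U10M", "V10M", "Z0M",
]
VERTICAL_VARS = ["CLOUD", "H", "OMEGA", "PL", "QI", "QL", "QV", "T", "U", "V"]
LEVELS = [34, 39, 41, 43, 44, 45, 48, 51, 53, 56, 63, 68, 71, 72]


def channel_index(name, level=None):
    """Hypothetical helper: map a variable name (plus model level) to a channel index."""
    if level is None:
        # Surface variables occupy channels 0..19.
        return SFC_VARS.index(name)
    # Vertical variables start after the surface block, one run of 14 levels per variable.
    return len(SFC_VARS) + VERTICAL_VARS.index(name) * len(LEVELS) + LEVELS.index(level)


def channel_names():
    """Hypothetical helper: 160 labels such as 'T2M' or 'T_72', in channel order."""
    return SFC_VARS + [f"{var}_{lev}" for var in VERTICAL_VARS for lev in LEVELS]
```

With this layout, out[0, channel_index("T2M"), :, :] would select 2 m temperature (channel 12), and channel_index("T", 72) gives 131, the air temperature channel on model level 72, the lowest MERRA-2 model level.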

worldPower555 commented 3 days ago

I have a similar issue. I recently ran into some problems while using the model for inference and would appreciate your help. I noticed that the model's input is based on MERRA-2 data in NetCDF format with 160 variables, so I assume the output has a similar structure, with a spatial resolution of 0.5 degrees (latitude) by 0.625 degrees (longitude).

After running the provided example notebook (.ipynb), I observed that the output tensor has shape (1, 160, 360, 576). I would like to save these outputs in NetCDF format for further meteorological analysis. Could you recommend a method or steps for this conversion? Specifically, I need to map each variable to its corresponding latitude, longitude, and (where applicable) vertical level, and ideally save the results as NetCDF files.

Thank you for your help, and I look forward to your reply!
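Once the output is an xarray Dataset, saving it to NetCDF is a single Dataset.to_netcdf call. A minimal sketch on a small synthetic surface field (random numbers, not real model output), assuming xarray plus a NetCDF backend (netCDF4 or scipy) are installed; the file name prithvi_sfc.nc is illustrative. The CF-style units attributes on latitude and longitude are optional but help tools such as ncview or CDO recognise the grid.

```python
import os
import tempfile

import numpy as np
import pandas as pd
import xarray as xr

# Synthetic stand-in for one surface field: (time, latitude, longitude)
# on the 0.5 x 0.625 degree grid discussed in this thread.
times = pd.date_range("2023-01-01T00:00:00", periods=2, freq="6h")
lat = np.arange(-90, 90, 0.5)      # 360 values: -90.0 ... 89.5
lon = np.arange(-180, 180, 0.625)  # 576 values: -180.0 ... 179.375
t2m = np.random.rand(len(times), lat.size, lon.size).astype("float32")

ds = xr.Dataset(
    {"T2M": (("time", "latitude", "longitude"), t2m)},
    coords={
        "time": times,
        # (dims, values, attrs) tuples attach CF metadata to each axis.
        "latitude": ("latitude", lat, {"units": "degrees_north", "standard_name": "latitude"}),
        "longitude": ("longitude", lon, {"units": "degrees_east", "standard_name": "longitude"}),
    },
)

# Write to NetCDF, then load it back fully into memory (this closes the file).
path = os.path.join(tempfile.mkdtemp(), "prithvi_sfc.nc")
ds.to_netcdf(path)
reloaded = xr.load_dataset(path)
```

The same to_netcdf call applies to the sfc_data and prs_data datasets returned by the to_xarray helper below.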

gajeshladhar commented 3 days ago

@worldPower555 Yes, correct. It would be better if both the model input and output could be treated as xarray xr.Dataset objects (read directly from NetCDF files); the entire workflow would be much easier to work with.
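Going the other way, from the two datasets back to a stacked (time, channels, lat, lon) array, is a reshape plus a concatenate. A sketch assuming the datasets have the layout produced by the to_xarray helper later in this thread (a prithvi variable with a vars dimension for surface fields, and variables and levels dimensions for vertical fields); to_channel_array is a hypothetical name. It only restacks channels and does not apply any normalization or other preprocessing the repository's input pipeline may expect.

```python
import numpy as np
import xarray as xr


def to_channel_array(sfc: xr.Dataset, prs: xr.Dataset) -> np.ndarray:
    """Hypothetical inverse of to_xarray: restack into (time, channels, lat, lon).

    ASSUMES surface channels come first, followed by vertical channels in
    variable-major order (all levels of one variable, then the next variable).
    """
    # Fix the dimension order explicitly so the reshape below is well defined.
    surface = sfc["prithvi"].transpose("time", "vars", "latitude", "longitude").values
    vertical = prs["prithvi"].transpose(
        "time", "variables", "levels", "latitude", "longitude"
    ).values
    n_time, n_var, n_lev, n_lat, n_lon = vertical.shape
    # Flatten (variables, levels) into one channel axis, variable-major.
    vertical = vertical.reshape(n_time, n_var * n_lev, n_lat, n_lon)
    return np.concatenate([surface, vertical], axis=1)
```

Applied to the output of to_xarray, to_channel_array(sfc_data, prs_data) should reproduce the original (time, 160, 360, 576) array.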

ankurk017 commented 3 days ago

@gajeshladhar @worldPower555 Thank you for using Prithvi-WxC. Here is a function that converts the model output to xarray format. It returns two separate xarray datasets: one for surface-level data and one for vertical-level data on the MERRA-2 model (eta) levels, with nominal pressures attached.

Example run on a random array:

sfc_data, prs_data = to_xarray(np.random.rand(10, 160, 360, 576), initial_time='2023-01-01T00:00:00')

import xarray as xr
import numpy as np
import pandas as pd
from datetime import datetime

def to_xarray(
    prediction,
    initial_time: str = '2000-01-01T00:00:00',
    freq: str = '6h'
):
    """
    Convert Prithvi WxC prediction data to xarray Datasets.

    This function takes a numpy array of prediction data and converts it into two xarray
    Datasets (surface variables and vertical-level variables) with appropriate dimensions,
    coordinates, and variable names. It also asserts that the input data has the expected
    shape.

    Parameters:
    -----------
    prediction : numpy.ndarray
        A 4D numpy array with dimensions (time, variables, latitude, longitude).
        Expected shape is (n_timesteps, 160, 360, 576). A torch tensor should be
        converted first, e.g. with prediction.detach().cpu().numpy().

    initial_time : str, optional
        The initial timestamp for the time dimension. Should be in a format that
        pandas.Timestamp can parse, e.g., 'YYYY-MM-DDTHH:MM:SS'.
        Default is '2000-01-01T00:00:00'.

    freq : str, optional
        The frequency of the time steps. This should be a pandas frequency string.
        Default is '6h' (6 hours).

    Returns:
    --------
    tuple of (xarray.Dataset, xarray.Dataset)
        The surface dataset, with dimensions (time, vars, latitude, longitude), and
        the vertical-level dataset, with dimensions
        (time, variables, levels, latitude, longitude).

    Raises:
    -------
    AssertionError
        If the input prediction array does not have the expected shape.

    Warnings:
    ---------
    UserWarning
        If the default initial_time is used.

    Notes:
    ------
    - The function assumes a specific channel order: 20 surface variables, followed
      by 10 vertical variables on 14 model levels each, in variable-major order
      (all 14 levels of one variable, then the next variable).
    - Longitude runs from -180 to 179.375 at 0.625 degree resolution (576 points).
    - Latitude runs from -90 to 89.5 at 0.5 degree resolution (360 points).
    - The time dimension is created based on the initial_time and freq parameters.

    Example:
    --------
    >>> import numpy as np
    >>> prediction = np.random.rand(10, 160, 360, 576)
    >>> sfc_data, prs_data = to_xarray(prediction, initial_time='2023-01-01T00:00:00', freq='6h')
    >>> print(sfc_data)
    >>> print(prs_data)
    """
    assert prediction.ndim == 4, f"Expected a 4D array (time, variables, lat, lon), but got {prediction.ndim}D"
    assert prediction.shape[1] == 160, f"Expected 160 variables, but got {prediction.shape[1]}"
    assert prediction.shape[2] == 360, f"Expected 360 latitudes, but got {prediction.shape[2]}"
    assert prediction.shape[3] == 576, f"Expected 576 longitudes, but got {prediction.shape[3]}"
    import warnings

    if initial_time == '2000-01-01T00:00:00':
        warnings.warn(
            "Setting default timestamp to 2000-01-01T00:00. If you want to use your own "
            "timestamp, please provide the initial_time argument in YYYY-MM-DDTHH:MM:SS "
            "format, or any format that pd.Timestamp accepts.",
            UserWarning,
        )

    lon = np.arange(-180, 180, 0.625)
    lat = np.arange(-90, 90, 0.5)

    start_time = pd.Timestamp(initial_time)
    time_range = pd.date_range(start=start_time, periods=len(prediction), freq=freq)

    # Plain NumPy array with dims (time, vars, latitude, longitude).
    prediction_merged = np.asarray(prediction)

    gt_data = xr.Dataset(
        {
            "prithvi": (
                ["time", "vars", "latitude", "longitude"],
                prediction_merged,
            ),
        },
        coords={
            "time": time_range,
            "vars": np.arange(0, 160),
            "latitude": lat,
            "longitude": lon,
        },
    )

    sfc_vars = [
        "EFLUX", "GWETROOT", "HFLUX", "LAI", "LWGAB", "LWGEM", 
        "LWTUP", "PS", "QV2M", "SLP", "SWGNT", "SWTNT", 
        "T2M", "TQI", "TQL", "TQV", "TS", "U10M", "V10M", "Z0M",
    ]

    vertical_vars = ["CLOUD", "H", "OMEGA", "PL", "QI", "QL", "QV", "T", "U", "V"]

    levels = [
        34.0, 39.0, 41.0, 43.0, 44.0, 45.0, 48.0, 51.0, 
        53.0, 56.0, 63.0, 68.0, 71.0, 72.0,
    ]

    nominal_pres_levels = [48, 109, 150, 208, 245, 288, 412, 525, 600, 700, 850, 925, 970, 985]

    gt_sfc_data = gt_data.isel(vars=np.arange(0, 20))
    gt_sfc_data["vars"] = sfc_vars

    gt_prs_data = gt_data.isel(vars=np.arange(20, 160))

    reshaped_data = gt_prs_data["prithvi"].values.reshape(-1, 10, 14, 360, 576)

    gt_prs_data = xr.Dataset(
        {
            "prithvi": (
                ("time", "variables", "levels", "latitude", "longitude"),
                reshaped_data,
            )
        },
        coords={
            "time": gt_prs_data["time"].values,
            "variables": vertical_vars,
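            "levels": levels,  # MERRA-2 model (eta) level numbers
            "nominal_pres": ("levels", nominal_pres_levels),  # approximate pressure (hPa) per level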
            "latitude": gt_prs_data["latitude"].values,
            "longitude": gt_prs_data["longitude"].values,
        },
    )

    global_attrs = {
        "description": "This is generated from the WxC Prithvi model output on eta level.",
        "vertical": 'eta-level',
        "model": "WxC Prithvi",
        "creation_date": datetime.now().strftime("%Y-%m-%d %H:%M:%S")
    }

    gt_sfc_data.attrs.update(global_attrs)
    gt_prs_data.attrs.update(global_attrs)

    return gt_sfc_data, gt_prs_data
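The returned datasets keep every field in a single prithvi variable indexed by a vars or variables dimension. For analysis it can be handier to split that dimension into one named data variable per field and to select vertical levels by nominal pressure. A sketch on a toy dataset, assuming nominal_pres is a coordinate along the levels dimension (rather than a dimension of its own); DataArray.to_dataset(dim=...), swap_dims and sel are standard xarray methods.

```python
import numpy as np
import xarray as xr

# Toy stand-in for a vertical-level dataset: 2 variables x 3 model levels on a 4 x 5 grid.
# ASSUMPTION: 'nominal_pres' is a coordinate along the 'levels' dimension.
prs = xr.Dataset(
    {
        "prithvi": (
            ("time", "variables", "levels", "latitude", "longitude"),
            np.random.rand(1, 2, 3, 4, 5),
        )
    },
    coords={
        "variables": ["T", "U"],
        "levels": [68, 71, 72],
        "nominal_pres": ("levels", [925, 970, 985]),
    },
)

# Split the 'variables' dimension into separate named data variables (T, U).
split = prs["prithvi"].to_dataset(dim="variables")

# Make nominal pressure the index of the level dimension, then select by it.
t_985 = split["T"].swap_dims({"levels": "nominal_pres"}).sel(nominal_pres=985)
```

For the surface dataset the equivalent split would be sfc_data["prithvi"].to_dataset(dim="vars"), giving data variables named T2M, PS, and so on.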