openclimatefix / power_perceiver

Machine learning experiments using the Perceiver IO model to forecast the electricity system (starting with solar)
MIT License

Prepared Batches after processing seem to have differing numbers of timesteps, either 32 or 31 #195

Closed: jacobbieker closed this issue 1 year ago

jacobbieker commented 1 year ago

Describe the bug

Some of the inputs seem to have differing numbers of timesteps, either 32 or 31, when they are put through the FullModel, resulting in concatenation errors in the gsp_query_generator. The first dimension, which should be examples * times, ends up being 992 (32 x 31) for some of them and 1024 (32 x 32) for others.
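
As a quick diagnostic, printing the shape of every array in a loaded batch makes it easy to spot which modalities carry 31 rather than 32 timesteps. The snippet below is only a sketch; batch is assumed to be the NumpyBatch returned by the DataLoader in the reproduction script further down.

def print_batch_shapes(batch) -> None:
    """Print the shape of every array-like entry in a NumpyBatch so the
    31- vs 32-timestep modalities can be spotted by eye."""
    for key, value in batch.items():
        if hasattr(value, "shape"):
            print(f"{key}: {tuple(value.shape)}")

# e.g. print_batch_shapes(batch) after batch = next(iter(dataloader))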

To Reproduce

from pathlib import Path

import pandas as pd
from torch.utils.data import DataLoader

from power_perceiver.load_prepared_batches.prepared_dataset import PreparedDataset
from power_perceiver.load_prepared_batches.data_sources import PV, GSP, HRVSatellite, NWP
from power_perceiver.production.model import FullModel

dataset = PreparedDataset([PV(history_duration=pd.Timedelta("90 min"),),
                           GSP(history_duration=pd.Timedelta("2 hours"),),
                           HRVSatellite(history_duration=pd.Timedelta("30 min"),),
                           NWP(history_duration=pd.Timedelta("1 hour"),)],
                          data_path=Path("/home/jacob/Development/power_perceiver/data_for_testing/"))
print(dataset)
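# batch_size=None disables automatic batching: each item from PreparedDataset is already a full pre-prepared batch.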
dataloader = DataLoader(dataset, num_workers=0, batch_size=None)
batch = next(iter(dataloader))
model = FullModel().eval()
model(batch)

This does require a slightly modified PreparedDataset to work past the #194 bug, so swap out that PreparedDataset for the one below. This version simply fills the topographic height with zeros rather than computing it with the Topography processor. It is intended as a drop-in replacement for the class in power_perceiver.load_prepared_batches.prepared_dataset, so it relies on the imports and helpers already present in that module.

@dataclass
class PreparedDataset(torch.utils.data.Dataset):
    """Load batches pre-prepared by `nowcasting_dataset`.

    Initialisation arguments:
        data_loaders: A list of instantiated data loader objects.
        data_path: Base path to the pre-prepared dataset. e.g. /path/to/v15/train/
        max_n_batches_per_epoch: If the user sets this to an int, then that int will be
            the maximum number of batches used per epoch. If left as None,
            all available batches will be loaded.
        xr_batch_processors: Functions which take an XarrayBatch, do processing
            *across* modalities, and return the processed XarrayBatch.
            Note that any processing *within* a modality should be done in
            PreparedDataSource.to_numpy.
        np_batch_processors: Functions which take a NumpyBatch, do processing
            *across* modalities, and return the processed NumpyBatch.
            Note that any processing *within* a modality should be done in
            PreparedDataSource.to_numpy.

    Attributes:
        n_batches: int. Set by _set_number_of_batches.
    """

    data_loaders: Iterable[PreparedDataSource]
    data_path: Optional[Path] = None
    max_n_batches_per_epoch: Optional[int] = None
    xr_batch_processors: Optional[Iterable[Callable]] = None
    np_batch_processors: Optional[Iterable[Callable]] = None
    topography_location: str = (
        "/home/jacob/Development/power_perceiver/data/europe_dem_2km_osgb.tif"
    )

    def __post_init__(self):
        # Sanity checks
        assert self.data_path.exists()
        assert len(self.data_loaders) > 0
        # Prepare PreparedDataSources.
        self._set_data_path_in_data_loaders()
        self._set_number_of_batches()
        np_batch_processors = [AlignGSPTo5Min(), EncodeSpaceTime(), SaveT0Time()]
        for modality_name in ["hrvsatellite", "gsp", "gsp_5_min", "pv", "nwp_target_time"]:
            np_batch_processors.append(SunPosition(modality_name=modality_name))
        #np_batch_processors.append(Topography(self.topography_location))
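        # Topography processor disabled; see the zero-fill workaround in __getitem__ below.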
        self.np_batch_processors = np_batch_processors
        super().__init__()

    def _set_data_path_in_data_loaders(self) -> None:
        for data_loader in self.data_loaders:
            data_loader.data_path = self.data_path

    def _set_number_of_batches(self) -> None:
        """Set number of batches.  Check every data source."""
        self.n_batches = None
        for data_loader in self.data_loaders:
            n_batches_for_data_source = data_loader.get_n_batches_available()
            if self.n_batches is None:
                self.n_batches = n_batches_for_data_source
            elif n_batches_for_data_source != self.n_batches:
                self.n_batches = min(self.n_batches, n_batches_for_data_source)
                _log.warning(
                    f"Warning! {data_loader} has a different number of batches to at"
                    " least one other modality!"
                    f" We'll use the minimum of the two values: {self.n_batches}"
                )
        if self.max_n_batches_per_epoch is not None:
            self.n_batches = min(self.n_batches, self.max_n_batches_per_epoch)
        assert self.n_batches is not None
        assert self.n_batches > 0

    def __len__(self) -> int:
        return self.n_batches

    def __getitem__(self, batch_idx: int) -> NumpyBatch:
        if batch_idx >= self.n_batches:
            raise KeyError(f"{batch_idx=} is out of bounds! {self.n_batches=}")
        xr_batch = self._get_xarray_batch(batch_idx)
        xr_batch = self._process_xr_batch(xr_batch)
        np_batch = self._xarray_to_numpy_batch(xr_batch)
        del xr_batch
        np_batch = self._process_np_batch(np_batch)
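        # Workaround for #194: fill the topographic height with zeros instead of computing it with Topography.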
        np_batch[BatchKey.hrvsatellite_surface_height] = np.zeros_like(np_batch[BatchKey.hrvsatellite_actual][:,0,0,:])
        return np_batch

    def _get_xarray_batch(self, batch_idx: int) -> XarrayBatch:
        """Load the completely un-modified batches from disk and store them in a dict."""
        xr_batch: XarrayBatch = {}
        for data_loader in self.data_loaders:
            xr_data_for_data_source = data_loader[batch_idx]
            xr_batch[data_loader.__class__] = xr_data_for_data_source
        return xr_batch

    def _process_xr_batch(self, xr_batch: XarrayBatch) -> XarrayBatch:
        """If necessary, do any processing which needs to be done across modalities,
        on the xr.Datasets."""
        if self.xr_batch_processors:
            for xr_batch_processor in self.xr_batch_processors:
                xr_batch = xr_batch_processor(xr_batch)
        return xr_batch

    def _xarray_to_numpy_batch(self, xr_batch: XarrayBatch) -> NumpyBatch:
        """Convert from xarray Datasets to numpy."""
        np_batch: NumpyBatch = {}
        for data_loader_class, xr_dataset in xr_batch.items():
            if data_loader_class == BatchKey.requested_timesteps:
                # `ReduceNumTimesteps` introduces a `requested_timesteps` key,
                # whose value is a np.ndarray.
                requested_timesteps = xr_dataset
                np_batch[BatchKey.requested_timesteps] = requested_timesteps
            else:
                np_data_for_data_source = data_loader_class.to_numpy(xr_dataset)
                np_batch.update(np_data_for_data_source)
        return np_batch

    def _process_np_batch(self, np_batch: NumpyBatch) -> NumpyBatch:
        """If necessary, do any processing which needs to be done across modalities,
        on the NumpyBatch."""
        if self.np_batch_processors:
            for np_batch_processor in self.np_batch_processors:
                np_batch = np_batch_processor(np_batch)
        return np_batch
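
Before the batch reaches the FullModel, a small sanity check on the NumpyBatch can also confirm which modality is one timestep short. The sketch below is only illustrative: it assumes the batch keys are BatchKey enum members and that the time coordinates follow a *_time_utc naming convention, so adjust the suffix if the enum uses different names.

import numpy as np

def check_timestep_counts(np_batch) -> None:
    """Compare the number of timesteps reported by each *_time_utc entry in a
    NumpyBatch, raising if the modalities disagree (e.g. 31 vs 32)."""
    counts = {}
    for key, value in np_batch.items():
        name = getattr(key, "name", str(key))
        if name.endswith("_time_utc"):
            # Time is assumed to be the last dimension of the time-coordinate arrays.
            counts[name] = np.asarray(value).shape[-1]
    print(counts)
    if len(set(counts.values())) > 1:
        raise ValueError(f"Inconsistent timestep counts across modalities: {counts}")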

Expected behavior All modalities in a prepared batch should end up with the same number of timesteps (32 here), so the flattened examples * times dimension is 1024 for every input to the FullModel.

JackKelly commented 1 year ago

Hey @jacobbieker, would you like me to try to help fix this? Or is this issue no longer relevant because we're hoping to use data pipes instead of the old power_perceiver.load_prepared_batches code?

jacobbieker commented 1 year ago

Hi, I think we can probably leave this and just focus on the datapipes now, as this whole issue seems to stem from the prepared batches themselves and how they are set up.