TorchSpatiotemporal / tsl

tsl: a PyTorch library for processing spatiotemporal data.
https://torch-spatiotemporal.readthedocs.io/
MIT License
255 stars 24 forks source link

Question: What is the best way to use multiple SpatioTemporalDatasets? #42

Open Anjum48 opened 1 month ago

Anjum48 commented 1 month ago

If you have many SpatioTemporalDatasets, e.g. experiments with different numbers of sensors (i.e. nodes), what would be the best way to combine them into a single dataset?

Or is it preferable to create a list of Data objects for each of the datasets above, concatenate them (e.g. data_list.extend(more_data)), and iterate over them using DisjointBatch?

Another option could be to create a ConcatDataset, but I think this breaks the batching, even though it does produce a list of Data objects as described above.

Anjum48 commented 3 weeks ago

Thought I might answer my own question now that I've got something that works.

It's possible to concatenate multiple SpatioTemporalDataset objects using something like this (I have each experiment as its own .pt file, which has each graph stored as a PyTorch Geometric Data object):

from torch_geometric.data import Dataset
from tsl.data import ImputationDataset, SpatioTemporalDataset

class ForecastingDataset(Dataset):
    """Expose one tsl ``SpatioTemporalDataset`` per experiment file as one dataset.

    Each row of ``metadata_df`` names a ``.pt`` file (column ``"file_name"``)
    under ``INPUT_PATH / "time"``. Each file is loaded with ``torch.load``,
    then passed through ``RadiusGraph`` + ``calculate_inverse_distance``. It is
    then wrapped in a ``SpatioTemporalDataset`` whose target is ``data.x.T``
    and whose connectivity is ``(data.edge_index, data.edge_attr)``. The
    per-file datasets are concatenated into a single flat index space.

    Args:
        metadata_df: One row per experiment; must contain ``"file_name"``.
        window, horizon, delay, stride, window_lag, horizon_lag: Forwarded
            unchanged to every ``SpatioTemporalDataset``.
        radius: ``r`` passed to ``RadiusGraph`` when building connectivity.
        transform, pre_transform, pre_filter: Standard PyG ``Dataset`` hooks.
    """

    def __init__(
        self,
        metadata_df: pd.DataFrame,
        window: int = 12,
        horizon: int = 1,
        delay: int = 0,
        stride: int = 1,
        window_lag: int = 1,
        horizon_lag: int = 1,
        radius: float = 1.0,
        transform=None,
        pre_transform=None,
        pre_filter=None,
    ):
        from torch.utils.data import ConcatDataset

        # PyG's ``Dataset.__init__`` takes ``root`` as its first positional
        # argument; pass the hooks by keyword so ``transform`` is not silently
        # consumed as ``root`` (and every other hook shifted by one slot).
        super().__init__(
            transform=transform,
            pre_transform=pre_transform,
            pre_filter=pre_filter,
        )
        self.df = metadata_df.reset_index(drop=True)
        self.trfm = Compose(
            [
                RadiusGraph(r=radius),
                calculate_inverse_distance,
            ]
        )

        datasets = []
        for _, row in self.df.iterrows():
            # NOTE(review): newer torch releases default ``torch.load`` to
            # ``weights_only=True``, which may reject pickled PyG ``Data``
            # objects -- confirm against the installed torch/PyG versions.
            data = torch.load(INPUT_PATH / "time" / row["file_name"])
            data = self.trfm(data)

            datasets.append(
                SpatioTemporalDataset(
                    target=data.x.T,
                    connectivity=(data.edge_index, data.edge_attr),
                    window=window,
                    horizon=horizon,
                    delay=delay,
                    stride=stride,
                    window_lag=window_lag,
                    horizon_lag=horizon_lag,
                )
            )

        # Build ONE flat ConcatDataset. Chaining ``a += b`` on torch datasets
        # nests a new ConcatDataset per file, so every lookup recurses once per
        # file (slow, and can exceed Python's recursion limit for many files).
        if not datasets:
            self.data = None
        elif len(datasets) == 1:
            self.data = datasets[0]
        else:
            self.data = ConcatDataset(datasets)

    def len(self):
        """Return the total number of samples across all experiments (0 if empty)."""
        return 0 if self.data is None else len(self.data)

    def get(self, idx):
        """Return sample ``idx`` from the concatenated per-experiment datasets."""
        if self.data is None:
            raise IndexError(f"index {idx} out of range for empty ForecastingDataset")
        return self.data[idx]

There's probably a better way of doing the above, but my data is currently small enough to fit in memory.

Then, to create a Lightning DataModule, you can use DisjointGraphLoader, which is not currently documented (I found it by exploring the code). This returns DisjointBatch objects:

import lightning as pl
from tsl.data.loader import DisjointGraphLoader

class STDataModule(pl.LightningDataModule):
    """Lightning DataModule serving cross-validation folds of ``ForecastingDataset``.

    Reads ``INPUT_PATH / "time" / "metadata.csv"``. If that table has no
    ``fold`` column, one is added via ``create_folds`` (grouped by
    ``run_id``). Train, validation and prediction datasets are then built
    for a single fold. Batches are produced by tsl's ``DisjointGraphLoader``.

    Args:
        window, horizon, delay, stride, window_lag, horizon_lag, radius:
            Forwarded to every ``ForecastingDataset``.
        batch_size: Batch size for all loaders.
        seed: ``random_state`` passed to ``create_folds``.
        n_folds: ``n_splits`` passed to ``create_folds``.
        num_workers: Worker processes for the train/validation loaders.
        fold: Fold used by ``setup`` when none is passed explicitly (Lightning
            itself only calls ``setup(stage=...)``).
        **kwargs: Accepted and ignored.
    """

    def __init__(
        self,
        window: int = 12,
        horizon: int = 1,
        delay: int = 0,
        stride: int = 1,
        window_lag: int = 1,
        horizon_lag: int = 1,
        radius: float = 1.0,
        batch_size: int = 32,
        seed: int = 48,
        n_folds: int = 5,
        num_workers: int = 1,
        fold: int = 0,
        **kwargs,
    ):
        super().__init__()
        self.dataset_kwargs = {
            "window": window,
            "horizon": horizon,
            "delay": delay,
            "stride": stride,
            "window_lag": window_lag,
            "horizon_lag": horizon_lag,
            "radius": radius,
        }
        self.batch_size = batch_size
        self.seed = seed
        self.n_folds = n_folds
        self.num_workers = num_workers
        self.fold = fold
        self.train_steps = 0

    def setup(self, stage=None, fold=None):
        """Build the datasets for ``fold`` (defaults to ``self.fold``).

        Runs for ``stage`` in ``{None, "fit", "validate", "predict"}``; other
        stages (e.g. ``"test"``) are ignored since no test loader exists.
        """
        if stage not in (None, "fit", "validate", "predict"):
            return
        if fold is None:
            fold = self.fold

        folder = INPUT_PATH / "time"
        metadata = pd.read_csv(folder / "metadata.csv")

        if "fold" not in metadata.columns:
            metadata = create_folds(
                metadata,
                n_splits=self.n_folds,
                random_state=self.seed,
                group="run_id",
            )

        trn_df = metadata.query(f"fold != {fold}")
        val_df = metadata.query(f"fold == {fold}")

        self.train_ds = ForecastingDataset(trn_df, **self.dataset_kwargs)
        self.valid_ds = ForecastingDataset(val_df, **self.dataset_kwargs)
        # Prediction uses exactly the validation split; reuse it rather than
        # loading and transforming every file a second time.
        self.pred_ds = self.valid_ds

        # The train loader uses drop_last=True, so only full batches count.
        self.train_steps = len(self.train_ds) // self.batch_size
        print(
            f"Fold {fold}:",
            len(self.train_ds),
            "train and",
            len(self.valid_ds),
            "valid samples",
        )

    def _loader(self, dataset, **overrides):
        """Build a ``DisjointGraphLoader`` with this module's shared settings."""
        loader_kwargs = {
            "batch_size": self.batch_size,
            "num_workers": self.num_workers,
            "pin_memory": True,
            # torch's DataLoader raises if persistent_workers=True while
            # num_workers == 0, so only request it when workers exist.
            "persistent_workers": self.num_workers > 0,
            "force_batch": True,
        }
        loader_kwargs.update(overrides)
        return DisjointGraphLoader(dataset, **loader_kwargs)

    def train_dataloader(self):
        """Shuffled training loader; incomplete final batch is dropped."""
        return self._loader(self.train_ds, shuffle=True, drop_last=True)

    def val_dataloader(self):
        """Validation loader in dataset order."""
        return self._loader(self.valid_ds)

    def predict_dataloader(self):
        """Prediction loader in dataset order, with a single worker process."""
        return self._loader(
            self.pred_ds,
            num_workers=1,
            persistent_workers=False,
            shuffle=False,
        )