Open Anjum48 opened 1 month ago
Thought I might answer my own question now that I've got something that works.
It's possible to concatenate multiple SpatioTemporalDataset
objects using something like this (I have each experiment as its own .pt file, which has each graph stored as a PyTorch Geometric Data
object):
from torch_geometric.data import Dataset
from tsl.data import ImputationDataset, SpatioTemporalDataset
class ForecastingDataset(Dataset):
    """Concatenate one ``SpatioTemporalDataset`` per experiment into a single
    map-style dataset.

    Each row of ``metadata_df`` names a ``.pt`` file holding a PyTorch
    Geometric ``Data`` object. Every file is loaded eagerly, a radius graph
    is (re)built on it, it is wrapped in a ``SpatioTemporalDataset``, and all
    of them are merged into ``self.data`` via ``+=``.

    Parameters
    ----------
    metadata_df : pd.DataFrame
        Must contain a ``file_name`` column; one row per experiment file.
    window, horizon, delay, stride, window_lag, horizon_lag : int
        Forwarded verbatim to each ``SpatioTemporalDataset``.
    radius : float
        Radius used by ``RadiusGraph`` when building edges.
    transform, pre_transform, pre_filter
        Standard PyG ``Dataset`` hooks, forwarded to the base class.
    """

    def __init__(
        self,
        metadata_df: pd.DataFrame,
        window: int = 12,
        horizon: int = 1,
        delay: int = 0,
        stride: int = 1,
        window_lag: int = 1,
        horizon_lag: int = 1,
        radius: float = 1.0,
        transform=None,
        pre_transform=None,
        pre_filter=None,
    ):
        # BUG FIX: torch_geometric.data.Dataset.__init__ has the signature
        # (root, transform, pre_transform, pre_filter). The original call
        # passed `transform` positionally, so it was interpreted as `root`
        # and the remaining hooks were shifted by one slot. Keywords make
        # the intent explicit and correct.
        super().__init__(
            root=None,
            transform=transform,
            pre_transform=pre_transform,
            pre_filter=pre_filter,
        )
        self.df = metadata_df.reset_index(drop=True)
        # Rebuild the graph topology per file: radius-based edges plus
        # inverse-distance edge weights.
        self.trfm = Compose(
            [
                RadiusGraph(r=radius),
                calculate_inverse_distance,
            ]
        )
        self.data = None
        for index, row in self.df.iterrows():
            data = torch.load(INPUT_PATH / "time" / row["file_name"])
            data = self.trfm(data)
            # tsl expects the target as (time, nodes); data.x is stored
            # (nodes, time) here, hence the transpose.
            tsl_data = SpatioTemporalDataset(
                target=data.x.T,
                connectivity=(data.edge_index, data.edge_attr),
                window=window,
                horizon=horizon,
                delay=delay,
                stride=stride,
                window_lag=window_lag,
                horizon_lag=horizon_lag,
            )
            # SpatioTemporalDataset supports in-place concatenation, so the
            # datasets are folded together one file at a time.
            if self.data is None:
                self.data = tsl_data
            else:
                self.data += tsl_data

    def len(self):
        # Robustness fix: with an empty metadata_df, self.data is still None
        # and len(None) would raise TypeError — report zero samples instead.
        return 0 if self.data is None else len(self.data)

    def get(self, idx):
        return self.data[idx]
There's probably a better way of doing the above, but my data is currently small enough to fit in memory.
Then to create a Lightning DataModule
, you can use DisjointGraphLoader
which is not currently documented (I found it by exploring the code). This returns DisjointBatch
objects:
import lightning as pl
from tsl.data.loader import DisjointGraphLoader
class STDataModule(pl.LightningDataModule):
    """Lightning DataModule that serves one cross-validation fold of
    ``ForecastingDataset`` through ``DisjointGraphLoader``.

    Parameters
    ----------
    window, horizon, delay, stride, window_lag, horizon_lag : int
        Windowing options forwarded to ``ForecastingDataset``.
    radius : float
        Radius-graph radius forwarded to ``ForecastingDataset``.
    batch_size : int
        Samples per batch for all three loaders.
    seed : int
        Random state used when fold assignments must be created.
    n_folds : int
        Number of cross-validation folds.
    num_workers : int
        Worker processes for the train/val loaders.
    **kwargs
        Ignored; accepted so the module can be built from a larger config.
    """

    def __init__(
        self,
        window: int = 12,
        horizon: int = 1,
        delay: int = 0,
        stride: int = 1,
        window_lag: int = 1,
        horizon_lag: int = 1,
        radius: float = 1.0,
        batch_size: int = 32,
        seed: int = 48,
        n_folds: int = 5,
        num_workers: int = 1,
        **kwargs,
    ):
        super().__init__()
        # Bundle the dataset options so all three splits are built identically.
        self.dataset_kwargs = {
            "window": window,
            "horizon": horizon,
            "delay": delay,
            "stride": stride,
            "window_lag": window_lag,
            "horizon_lag": horizon_lag,
            "radius": radius,
        }
        self.batch_size = batch_size
        self.seed = seed
        self.n_folds = n_folds
        self.num_workers = num_workers
        self.train_steps = 0

    def setup(self, stage=None, fold: int = 0):
        """Build train/valid/predict datasets for the given fold."""
        if stage == "fit" or stage == "predict":
            folder = INPUT_PATH / "time"
            metadata = pd.read_csv(folder / "metadata.csv")
            # Assign folds lazily the first time; grouped by run_id so all
            # samples of one run land in the same fold.
            if "fold" not in metadata.columns:
                metadata = create_folds(
                    metadata,
                    n_splits=self.n_folds,
                    random_state=self.seed,
                    group="run_id",
                )
            trn_df = metadata.query(f"fold != {fold}")
            val_df = metadata.query(f"fold == {fold}")
            self.train_ds = ForecastingDataset(trn_df, **self.dataset_kwargs)
            self.valid_ds = ForecastingDataset(val_df, **self.dataset_kwargs)
            self.pred_ds = ForecastingDataset(val_df, **self.dataset_kwargs)
            # BUG FIX: was a float (true division). The train loader uses
            # drop_last=True, so the number of optimizer steps per epoch is
            # the floor of samples / batch_size — keep it an int so LR
            # schedulers and progress accounting get a whole number.
            self.train_steps = len(self.train_ds) // self.batch_size
            print(
                f"Fold {fold}:",
                len(self.train_ds),
                "train and",
                len(self.valid_ds),
                "valid samples",
            )

    def train_dataloader(self):
        return DisjointGraphLoader(
            self.train_ds,
            num_workers=self.num_workers,
            batch_size=self.batch_size,
            shuffle=True,
            drop_last=True,
            pin_memory=True,
            persistent_workers=True,
            force_batch=True,
        )

    def val_dataloader(self):
        return DisjointGraphLoader(
            self.valid_ds,
            batch_size=self.batch_size,
            num_workers=self.num_workers,
            pin_memory=True,
            persistent_workers=True,
            force_batch=True,
        )

    def predict_dataloader(self):
        # NOTE(review): single worker and no persistent_workers here —
        # presumably deliberate for prediction; confirm if throughput matters.
        return DisjointGraphLoader(
            self.pred_ds,
            batch_size=self.batch_size,
            num_workers=1,
            pin_memory=True,
            shuffle=False,
            force_batch=True,
        )
If you have many SpatioTemporalDatasets, e.g. experiments with different numbers of sensors (i.e. nodes), what would be the best way to combine them into a single dataset?
Or is it preferred to create a list of
Data
objects for each dataset above, concatenate them (i.e. data_list.extend(more_data)
) and iterate using DisjointBatch
? Another option could be to create a
ConcatDataset
, but I think this breaks the batching — though it does create a list of Data
objects as I described above.