benedekrozemberczki / pytorch_geometric_temporal

PyTorch Geometric Temporal: Spatiotemporal Signal Processing with Neural Machine Learning Models (CIKM 2021)

StaticGraphTemporalSignalBatch does not provide batch #227

Closed beppe2hd closed 1 year ago

beppe2hd commented 1 year ago

Hi, I'm currently working with the code provided in https://github.com/Ditsuhi/GNN_Air_Quality and I'm having some trouble understanding the idea behind StaticGraphTemporalSignalBatch.

In a nutshell, I'm setting up the data loader in the following way:

from torch_geometric_temporal.signal import StaticGraphTemporalSignalBatch


class MadridDatasetLoader(object):

    def __init__(self, data_norm, edges, edge_weights, batch):
        super(MadridDatasetLoader, self).__init__()

        self.data_norm = data_norm
        self.edges = edges
        self.edge_weights = edge_weights
        self.batch = batch

    def _generate_task(self, num_timesteps_in: int = 6, num_timesteps_out: int = 6):

        # Offset between the input window and the target window;
        # it can be assigned as one of {0, 12, 24, 36}.
        time_steps_starter = 0
        indices = [
            (i, i + time_steps_starter + (num_timesteps_in + num_timesteps_out))
            for i in range(self.data_norm.shape[2] - (time_steps_starter + num_timesteps_in + num_timesteps_out) + 1)
        ]

        # Generate sliding-window observations: all channels over the input
        # window as features, channel 0 over the output window as target.
        features, targets = [], []
        for i, j in indices:
            features.append((self.data_norm[:, :, i: i + num_timesteps_in]).numpy())
            targets.append((self.data_norm[:, 0, i + num_timesteps_in + time_steps_starter: j]).numpy())

        self.features = features
        self.targets = targets

    def get_dataset(
            self, num_timesteps_in: int = 6, num_timesteps_out: int = 6
    ) -> StaticGraphTemporalSignalBatch:

        self._generate_task(num_timesteps_in, num_timesteps_out)
        dataset = StaticGraphTemporalSignalBatch(
            self.edges, self.edge_weights, self.features, self.targets, self.batch)

        return dataset
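
For completeness, this is roughly how I instantiate and iterate the loader (the tensors below are illustrative placeholders, not my real Madrid data):

import numpy as np
import torch

# Placeholder data: 24 nodes, 18 channels, 500 time steps;
# edges given as a [2, num_edges] array (not the real dataset).
data_norm = torch.rand(24, 18, 500)
edges = np.array([[0, 1, 2], [1, 2, 0]])
edge_weights = np.ones(3)
batch = np.zeros(24, dtype=np.int64)  # every node assigned to graph 0

loader = MadridDatasetLoader(data_norm, edges, edge_weights, batch)
dataset = loader.get_dataset(num_timesteps_in=12, num_timesteps_out=12)

for snapshot in dataset:
    print(snapshot)  # one snapshot per iteration
    break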

And this is the portion of the code devoted to the training phase:

for epoch in range(5):
    loss = 0
    step = 0
    for snapshot in dataset:
        snapshot = snapshot.to(device)
        # Get model predictions
        y_hat = model(snapshot.x, snapshot.edge_index, snapshot.edge_attr)
        # Accumulate the mean squared error over the snapshots
        loss = loss + torch.mean((y_hat - snapshot.y) ** 2)
        step += 1
        # DEBUG POINT

    # Average over the number of snapshots and take one optimizer step per epoch
    loss = loss / step
    loss.backward()
    optimizer.step()
    optimizer.zero_grad()

Anyway, going through a debug session, it seems that the returned snapshot is not a batch of 64 input elements but just a single instance of my input. Indeed, the returned y_hat is also a single instance, and in the end the number of iterations in an epoch corresponds to the number of instances in the dataset.

snapshot
Out[4]: DataBatch(x=[24, 18, 12], edge_index=[2, 552], edge_attr=[552], y=[24, 12], batch=[64])
type(snapshot)
Out[5]: torch_geometric.data.batch.DataBatch
y_hat.shape
Out[6]: torch.Size([24, 12])
snapshot.x.shape
Out[7]: torch.Size([24, 18, 12])
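
If I read this output correctly, each iteration yields a single temporal snapshot as a DataBatch, and the batch attribute is a node-to-graph assignment vector within that snapshot, rather than a stack of 64 snapshots. Here is a toy example of what I mean (toy numbers, assuming the documented constructor arguments):

import numpy as np
from torch_geometric_temporal.signal import StaticGraphTemporalSignalBatch

# Two copies of a 4-node ring stacked into one disconnected 8-node graph;
# the batch vector assigns nodes 0-3 to graph 0 and nodes 4-7 to graph 1.
edge_index = np.array([[0, 1, 2, 3, 4, 5, 6, 7],
                       [1, 2, 3, 0, 5, 6, 7, 4]])
edge_weight = np.ones(8)
features = [np.random.rand(8, 2) for _ in range(10)]  # 10 snapshots
targets = [np.random.rand(8) for _ in range(10)]
batch = np.array([0, 0, 0, 0, 1, 1, 1, 1])

dataset = StaticGraphTemporalSignalBatch(
    edge_index, edge_weight, features, targets, batch)

# Iteration is still over time: 10 snapshots, each a DataBatch holding
# the two graphs at that time step, not a mini-batch of 10 snapshots.
for snapshot in dataset:
    print(snapshot)
    break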

In order to understand better how it works, I tried to use StaticGraphTemporalSignal instead of StaticGraphTemporalSignalBatch, but it seems that nothing changes; the situation is the same:

Out[1]: Data(x=[24, 18, 12], edge_index=[2, 552], edge_attr=[552], y=[24, 12])
type(snapshot)
Out[2]: torch_geometric.data.data.Data
y_hat.shape
Out[3]: torch.Size([24, 12])
snapshot.x.shape
Out[4]: torch.Size([24, 18, 12])

Finally, my question is: what am I doing wrong? How can I exploit the batches in the right way?
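
In other words, is the intended pattern something like the following sketch, where consecutive window samples are stacked along the node dimension and the static graph is replicated block-diagonally? The helper make_batched_signal below is just my guess, not something from the library:

import numpy as np
from torch_geometric_temporal.signal import StaticGraphTemporalSignalBatch

def make_batched_signal(edges, edge_weights, features, targets, batch_size):
    """Hypothetical helper: group consecutive window samples into mini-batches
    by stacking them along the node dimension and replicating the static graph."""
    num_nodes = features[0].shape[0]
    num_batches = len(features) // batch_size  # drop the last partial batch

    # Replicate the static graph batch_size times, shifting node indices so
    # the copies form one block-diagonal disconnected graph.
    big_edges = np.concatenate(
        [edges + b * num_nodes for b in range(batch_size)], axis=1)
    big_weights = np.tile(edge_weights, batch_size)
    # Node-to-graph assignment inside each mini-batch.
    batch_index = np.repeat(np.arange(batch_size), num_nodes)

    big_features = [
        np.concatenate(features[k * batch_size:(k + 1) * batch_size], axis=0)
        for k in range(num_batches)
    ]
    big_targets = [
        np.concatenate(targets[k * batch_size:(k + 1) * batch_size], axis=0)
        for k in range(num_batches)
    ]
    return StaticGraphTemporalSignalBatch(
        big_edges, big_weights, big_features, big_targets, batch_index)

With this construction each iteration would process batch_size windows at once, which is what I expected StaticGraphTemporalSignalBatch to do out of the box.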