benedekrozemberczki / pytorch_geometric_temporal

PyTorch Geometric Temporal: Spatiotemporal Signal Processing with Neural Machine Learning Models (CIKM 2021)
MIT License

Why does the A3T-GCN model's predicted curve never match the real curve? #179

Closed. Meow10 closed this issue 1 year ago.

Meow10 commented 2 years ago

The code for A3TGCN:

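import torch
import torch.nn.functional as F
from torch_geometric_temporal.nn.recurrent import A3TGCN

class TemporalGNN(torch.nn.Module):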
    def __init__(self, node_features, periods):
        super(TemporalGNN, self).__init__()

        self.tgnn = A3TGCN(in_channels=node_features, 
                           out_channels=32, 
                           periods=periods)

        self.linear = torch.nn.Linear(32, periods)

    def forward(self, x, edge_index):
        """
        x = Node features for T time steps
        edge_index = Graph edge indices
        """
        h = self.tgnn(x, edge_index)
        h = F.relu(h)
        h = self.linear(h)
        return h

model = TemporalGNN(node_features=2, periods=12).to(device)
optimizer = torch.optim.Adam(model.parameters(), lr=args.learning_rate, weight_decay=args.weight_decay)
scheduler = torch.optim.lr_scheduler.ExponentialLR(optimizer, gamma=0.9)

model.train()

minl = 1e5
print("Running training...")
for epoch in range(200): 
    loss = 0
    step = 0
    for snapshot in train_dataset:
        snapshot = snapshot.to(device)
        y_hat = model(snapshot.x, snapshot.edge_index)
        loss = loss + torch.mean((y_hat-snapshot.y)**2) 
        step += 1
        if step > subset:
            break

    # step already equals the number of snapshots summed, so divide by step
    loss = loss / step
    loss.backward()
    optimizer.step()
    optimizer.zero_grad()
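
A minimal sketch of the counter arithmetic in the loop above, in plain Python with no tensors. It replays only the `step` / `subset` logic to show how many loss terms get summed, and how dividing by that count compares with dividing by one more than it. The sizes (`num_snapshots=100`, `subset=10`) and the constant per-snapshot loss of 0.5 are made up for illustration, and `summed_terms` is a hypothetical helper, not part of the library.

```python
def summed_terms(num_snapshots, subset):
    """Replay only the counter logic of the training loop; no tensors involved."""
    step = 0
    for _ in range(num_snapshots):
        step += 1          # one loss term has been added to the running sum
        if step > subset:  # same early-exit condition as in the loop above
            break
    return step

# Hypothetical sizes, chosen only for illustration.
terms = summed_terms(num_snapshots=100, subset=10)
total = 0.5 * terms                # pretend every snapshot had a loss of 0.5
mean_loss = total / terms          # divides by the number of summed terms
shrunk_loss = total / (terms + 1)  # divides by one more than that

print(terms, mean_loss, shrunk_loss)
```

Dividing by one more than the number of summed terms only rescales the loss by a constant factor, so it mainly affects the value that gets logged rather than the shape of the predictions.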

On the METR-LA dataset, the predicted curve is smooth and very different from the real curve. [image: result1]

After parameter tuning: [image: result2]
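
Since parameter tuning is mentioned, one thing that may be worth checking is the learning-rate schedule. The code creates `ExponentialLR(optimizer, gamma=0.9)`, but the loop as posted never calls `scheduler.step()`, so the rate stays constant there. The sketch below shows, in plain Python, what the schedule would look like if it were stepped once per epoch: the rate after `k` steps is the initial rate times `gamma ** k`. The starting rate of `0.01` is an assumption, because `args.learning_rate` is not shown in the post, and `exponential_lr` is a hypothetical helper.

```python
def exponential_lr(initial_lr, gamma, epochs):
    """Learning rate after 0..epochs scheduler steps: initial_lr * gamma ** k."""
    return [initial_lr * gamma ** k for k in range(epochs + 1)]

# Assumed starting rate; the real value of args.learning_rate is not shown.
schedule = exponential_lr(initial_lr=0.01, gamma=0.9, epochs=200)

for k in (0, 10, 50, 200):
    print(f"after {k:3d} steps: lr = {schedule[k]:.3e}")
```

With `gamma=0.9`, the rate falls below a billionth of its starting value within 200 steps, so if the scheduler were stepped every epoch, most of the 200 epochs would run with an almost-zero learning rate.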

Doradx commented 1 year ago

I may have found the reason recently; see https://github.com/benedekrozemberczki/pytorch_geometric_temporal/issues/204

aurorarossi commented 1 year ago

Hi! I recently wrote this tutorial about traffic prediction and the A3T-GCN model. You may find it useful, even though it is written in the Julia programming language.