benedekrozemberczki / pytorch_geometric_temporal

PyTorch Geometric Temporal: Spatiotemporal Signal Processing with Neural Machine Learning Models (CIKM 2021)
MIT License

Number of channels does not change model performance substantially #104

Closed: MoRoBe-Work closed this issue 2 years ago.

MoRoBe-Work commented 3 years ago

Dear Benedek (and whoever else may read this and have an idea what's going wrong),

For several months now I have been trying to use your implementation of the A3T-GCN model for traffic prediction, as in the original paper. Although I have had some success, there are questions I have not been able to answer by analyzing the code and its execution. I therefore hope you can give me a hint about what I'm missing. The main concern is as follows:

In addition to A3T-GCN, I also use your GMAN implementation. The GMAN model consistently and significantly outperforms the A3T-GCN model. So far so good; it might just be better suited to our data. However, there is one thing I really don't get. In the A3T-GCN example

https://github.com/benedekrozemberczki/pytorch_geometric_temporal/blob/b8fab67b5745ad201c18c6403430a14f4cb5db94/examples/recurrent/a3tgcn_example.py#L19

you give the A3T-GCN more output channels than there are target values. It is perfectly clear to me why this makes sense: it increases the number of trainable parameters and allows different filters for different situations in the input data. If the number of output channels is set too low, the number of trainable parameters drops below 1k, which obviously hurts the representational capacity of the network. Thus, I designed our own model as follows:

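from torch.nn import Identity, Linear
from torch_geometric_temporal.nn.recurrent import A3TGCN

# Excerpt from the model's __init__; channels, periods_in, periods_out, projection_factor,
# improved, cached and add_self_loops are constructor arguments.
self.a3tgcn = A3TGCN(in_channels=channels, out_channels=projection_factor * periods_out, periods=periods_in, improved=improved, cached=cached, add_self_loops=add_self_loops)
self.preprocessing_layer = Identity()  # placeholder, currently a no-op
self.postprocessing_layer = Linear(projection_factor * periods_out, periods_out)  # hidden channels -> output periods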

As you can see, the "preprocessing layer" is only a placeholder; I tried some things there, but to no avail. improved, cached and add_self_loops are left at their default values; they are exposed only so that they can be changed. The number of in_channels is governed by the number of features per time step and node: it is 1, 2 or 4, mostly 2. Four periods are used as input and four are to be returned. The projection factor then controls the actual size of the model. The postprocessing_layer subsequently projects the output of the A3T-GCN onto the desired number of outputs, just like in the chickenpox example.
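To make the size argument concrete, here is a rough parameter count as a function of projection_factor. It is a sketch under an assumption about A3TGCN's internals: that its TGCN cell consists of three GCNConv layers (in_channels to out_channels, with bias) and three Linear(2 * out_channels, out_channels) gate layers, plus one attention score per input period when that score is trainable. The function names are hypothetical; on a real model, `sum(p.numel() for p in model.parameters() if p.requires_grad)` gives the authoritative number.

```python
def a3tgcn_param_count(in_channels: int, out_channels: int, periods: int,
                       trainable_attention: bool = True) -> int:
    """Estimate the trainable parameters of one A3TGCN layer (assumed layer structure)."""
    # Assumed: three GCNConv layers (update gate, reset gate, candidate), weight + bias each.
    gcn = 3 * (in_channels * out_channels + out_channels)
    # Assumed: three Linear(2 * out_channels, out_channels) layers, weight + bias each.
    gates = 3 * (2 * out_channels * out_channels + out_channels)
    # One attention score per input period, counted only if it is a trainable parameter.
    attention = periods if trainable_attention else 0
    return gcn + gates + attention


def model_param_count(channels: int, periods_in: int, periods_out: int,
                      projection_factor: int, trainable_attention: bool = True) -> int:
    """A3TGCN(channels -> projection_factor * periods_out) followed by a Linear projection."""
    hidden = projection_factor * periods_out
    # Linear(hidden, periods_out): weight matrix plus bias.
    post = hidden * periods_out + periods_out
    return a3tgcn_param_count(channels, hidden, periods_in, trainable_attention) + post


if __name__ == "__main__":
    for pf in (1, 2, 4, 8, 16):
        print(pf, model_param_count(channels=2, periods_in=4, periods_out=4, projection_factor=pf))
```

With channels=2 and four input and output periods, this estimate stays below 1k for projection factors 1 and 2 and grows roughly quadratically with the hidden width after that, because the gate layers dominate.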

However, the network's performance after training is practically the same regardless of this size. I get RMSEs of about 12 for all model sizes, while GMAN reaches around 6. I analyzed the network outputs, and it is not just the errors but the actual outputs of the network that are independent of size. In fact, both networks seem to somewhat reliably predict the average speed for most roads in the network, barely adapting to the current situation.
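One way to quantify "it mostly predicts each road's average speed" is to compare the model against a baseline that does exactly that, and to measure how far the model's outputs are from that baseline. A minimal sketch in plain Python, assuming predictions and targets have been flattened into lists with a matching list of node ids; the helper names and data layout are hypothetical.

```python
import math
from statistics import mean
from typing import Dict, List, Sequence


def rmse(pred: Sequence[float], target: Sequence[float]) -> float:
    """Root-mean-square error between two equal-length, non-empty sequences."""
    if len(pred) != len(target) or len(pred) == 0:
        raise ValueError("pred and target must be non-empty and of equal length")
    return math.sqrt(sum((p - t) ** 2 for p, t in zip(pred, target)) / len(pred))


def mean_baseline(train: Dict[str, List[float]], nodes: Sequence[str]) -> List[float]:
    """Predict every sample as its node's mean training speed, ignoring the current input."""
    node_mean = {node: mean(values) for node, values in train.items()}
    return [node_mean[n] for n in nodes]


def compare_to_baseline(model_pred: Sequence[float], target: Sequence[float],
                        train: Dict[str, List[float]], nodes: Sequence[str]) -> Dict[str, float]:
    baseline = mean_baseline(train, nodes)
    return {
        "model_rmse": rmse(model_pred, target),
        "baseline_rmse": rmse(baseline, target),
        # Close to 0 means the model is effectively reproducing the per-node mean.
        "model_vs_baseline_rmse": rmse(model_pred, baseline),
    }
```

If model_rmse is close to baseline_rmse and model_vs_baseline_rmse is small, the network is not using the current input window in any meaningful way.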

Does anyone, especially @benedekrozemberczki since you wrote the code, have an idea why the model reacts so little to the input data at each time step? I will not look into this over the weekend, but next week I will be happy to run any test you suggest and to share implementations and data as far as possible. Thanks in advance to anyone willing to even think about this.

Best Regards, MoRoBe

benedekrozemberczki commented 3 years ago

What is your exact architecture? The example does not have a proper hidden state update!

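import torch
import torch.nn.functional as F
from torch_geometric_temporal.nn.recurrent import A3TGCN


class RecurrentGCN(torch.nn.Module):
    def __init__(self, node_features):
        super(RecurrentGCN, self).__init__()
        # A3TGCN(in_channels, out_channels, periods): 32 hidden channels, 1 period
        self.recurrent = A3TGCN(node_features, 32, 1)
        self.linear = torch.nn.Linear(32, 1)

    def forward(self, x, edge_index, edge_weight, h):
        # Reshape x to (num_nodes, node_features, periods=1) and pass in the previous hidden state h
        h = self.recurrent(x.view(x.shape[0], x.shape[1], 1), edge_index, edge_weight, h)
        h_out = F.relu(h)
        y = self.linear(h_out)
        # Return h as well, so the caller can feed it into the next snapshot
        return y, h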

At the start of each epoch, the hidden state h is reset to None.
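The loop this implies can be sketched as follows: h starts as None at the beginning of an epoch and the value returned for one snapshot is fed into the next. This is a framework-free skeleton so it runs without PyTorch; run_epoch and the stub model are hypothetical, and the snapshot attributes (x, edge_index, edge_attr, y) are assumed to mirror the library's snapshot objects. In a real training loop you would accumulate a tensor cost and then call cost.backward(), optimizer.step() and optimizer.zero_grad().

```python
from types import SimpleNamespace
from typing import Any, Callable, Iterable


def run_epoch(model: Callable, snapshots: Iterable,
              loss_fn: Callable[[Any, Any], float]) -> float:
    """One pass over the temporal snapshots, carrying the hidden state between them."""
    h = None  # fresh hidden state at the start of every epoch
    total, steps = 0.0, 0
    for snapshot in snapshots:
        # The hidden state returned for this snapshot is passed to the next one.
        y_hat, h = model(snapshot.x, snapshot.edge_index, snapshot.edge_attr, h)
        total += loss_fn(y_hat, snapshot.y)
        steps += 1
    # Backward pass and optimizer step omitted in this framework-free sketch.
    return total / max(steps, 1)


def demo_hidden_state_reset() -> list:
    """Record the h each call receives over two epochs of three snapshots."""
    seen = []

    def stub_model(x, edge_index, edge_weight, h):
        seen.append(h)
        return x, (0 if h is None else h) + 1  # pretend the hidden state is a step counter

    snaps = [SimpleNamespace(x=1.0, edge_index=None, edge_attr=None, y=1.0) for _ in range(3)]
    for _ in range(2):
        run_epoch(stub_model, snaps, lambda a, b: (a - b) ** 2)
    return seen
```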

elmahyai commented 2 years ago

Making the attention weights trainable improved the training a lot for me. #122

https://www.kaggle.com/elmahy/a3t-gcn-for-traffic-forecasting
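For context on why trainable attention matters: A3T-GCN combines the hidden states computed for the individual input periods with a softmax-weighted sum. If the attention scores are stored as a plain tensor instead of being registered as a torch.nn.Parameter, they do not appear in model.parameters(), so the optimizer never updates them and the period weighting presumably stays at its random initialization. A minimal pure-Python sketch of the aggregation step for one node (hypothetical function names, not the library's code):

```python
import math
from typing import List, Sequence


def softmax(scores: Sequence[float]) -> List[float]:
    """Numerically stable softmax over one score per input period."""
    largest = max(scores)  # subtract the max before exponentiating to avoid overflow
    exps = [math.exp(s - largest) for s in scores]
    total = sum(exps)
    return [e / total for e in exps]


def attend_over_periods(scores: Sequence[float],
                        period_states: Sequence[Sequence[float]]) -> List[float]:
    """H = sum_p softmax(scores)[p] * H_p for one node's per-period hidden states."""
    if len(scores) != len(period_states):
        raise ValueError("need exactly one score per period")
    weights = softmax(scores)
    width = len(period_states[0])
    return [sum(w * state[i] for w, state in zip(weights, period_states)) for i in range(width)]
```

With equal scores every period contributes equally; a learned score vector can instead concentrate the weight on the most informative periods, which is only possible if the optimizer can update the scores.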