benedekrozemberczki / pytorch_geometric_temporal

PyTorch Geometric Temporal: Spatiotemporal Signal Processing with Neural Machine Learning Models (CIKM 2021)
MIT License

How to format `edge_indices` for DynamicGraphTemporalSignalBatch? #201

Closed: LaurentBerder closed this issue 1 year ago.

LaurentBerder commented 1 year ago

Hi,

I'm trying to implement a model for a dynamic spatio-temporal graph, so the number of edges is not the same at every timestamp.

As a result, while I can create regularly shaped tensors for edge_weights, features and targets, I can't do so for edge_indices, so I fall back to a list of lists.
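A ragged outer list is not a problem in itself. The traceback below shows `_get_edge_index` converting only `self.edge_indices[time_index]`, one timestamp at a time. Here is a minimal sketch with hypothetical toy data: the edge count changes over time, but each timestamp stays a regular `(2, E_t)` integer array. It assumes the usual PyTorch Geometric `edge_index` layout, with source nodes in row 0 and target nodes in row 1.

```python
import numpy as np

# Hypothetical toy data: 3 timestamps over 5 nodes, with a different
# number of edges at each timestamp (4, 2 and 3 edges respectively).
edge_indices = [
    np.array([[0, 1, 2, 3],
              [1, 2, 3, 4]], dtype=np.int64),  # shape (2, 4)
    np.array([[0, 4],
              [4, 0]], dtype=np.int64),        # shape (2, 2)
    np.array([[1, 2, 3],
              [0, 0, 0]], dtype=np.int64),     # shape (2, 3)
]

# One weight per edge, so edge_weights[t] has shape (E_t,).
edge_weights = [np.ones(ei.shape[1], dtype=np.float32) for ei in edge_indices]

for t, (ei, ew) in enumerate(zip(edge_indices, edge_weights)):
    # Each snapshot on its own is a rectangular 2 x E_t array.
    assert ei.ndim == 2 and ei.shape[0] == 2
    assert ew.shape == (ei.shape[1],)
    print(t, ei.shape, ew.shape)
```

Only the outer Python list varies in length across its entries. No single entry is ragged.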

Here are the dimensions that I have:

I can still create the DynamicGraphTemporalSignalBatch object without errors, as follows:

from torch_geometric_temporal.signal import DynamicGraphTemporalSignalBatch

dataset = DynamicGraphTemporalSignalBatch(edge_indices=edge_indices,
                                          edge_weights=edge_weights,
                                          features=features,
                                          targets=targets,
                                          batches=batches)
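Getting no error at construction does not guarantee that the snapshots are usable. In the traceback further down, `torch.LongTensor(self.edge_indices[time_index])` runs inside `__getitem__`, so each entry appears to be converted only when it is indexed. Below is a minimal eager check in NumPy. It assumes `edge_indices[t]` should have shape `(2, E_t)`, `edge_weights[t]` should have shape `(E_t,)`, and `features[t]` and `batches[t]` should share the node count `N_t`. `check_snapshots` is a hypothetical helper, not a library function.

```python
import numpy as np


def check_snapshots(edge_indices, edge_weights, features, batches):
    """Validate every timestamp up front; return a list of problem strings."""
    problems = []
    for t, (ei, ew, x, b) in enumerate(
            zip(edge_indices, edge_weights, features, batches)):
        try:
            # Ragged nested lists cannot become a regular integer array.
            ei = np.asarray(ei, dtype=np.int64)
        except ValueError as err:
            problems.append(f"t={t}: edge index is ragged ({err})")
            continue
        if ei.ndim != 2 or ei.shape[0] != 2:
            problems.append(f"t={t}: edge index has shape {ei.shape}, expected (2, E)")
            continue
        ew = np.asarray(ew)
        if ew.shape != (ei.shape[1],):
            problems.append(f"t={t}: weights have shape {ew.shape} for {ei.shape[1]} edges")
        n_nodes = np.asarray(x).shape[0]
        if np.asarray(b).shape != (n_nodes,):
            problems.append(f"t={t}: batch vector does not cover {n_nodes} nodes")
        if ei.size and ei.max() >= n_nodes:
            problems.append(f"t={t}: edge points to node {ei.max()}, only {n_nodes} nodes")
    return problems


# Hypothetical data: one well-formed timestamp and one ragged one.
good_t = np.array([[0, 1], [1, 2]])
ragged_t = [[[0, 1, 2, 3, 4, 5], [1, 2, 3, 4, 5, 0]], [[0, 1, 2, 3], [1, 2, 3, 0]]]
report = check_snapshots(
    edge_indices=[good_t, ragged_t],
    edge_weights=[np.ones(2), np.ones(10)],
    features=[np.zeros((3, 4)), np.zeros((10, 4))],
    batches=[np.zeros(3, dtype=np.int64), np.zeros(10, dtype=np.int64)],
)
for line in report:
    print(line)
```

Running such a check before building the dataset reports the bad timestamps up front. Otherwise they only surface later, when the dataset is indexed or iterated.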

But I can't use it afterwards: even indexing it raises an error:

dataset[0]
---------------------------------------------------------------------------
ValueError                                Traceback (most recent call last)
<ipython-input-294-55a9d71005c6> in <module>
----> 1 dataset[0]

/usr/local/lib/python3.7/dist-packages/torch_geometric_temporal/signal/dynamic_graph_temporal_signal_batch.py in __getitem__(self, time_index)
    132         else:
    133             x = self._get_feature(time_index)
--> 134             edge_index = self._get_edge_index(time_index)
    135             edge_weight = self._get_edge_weight(time_index)
    136             batch = self._get_batch_index(time_index)

/usr/local/lib/python3.7/dist-packages/torch_geometric_temporal/signal/dynamic_graph_temporal_signal_batch.py in _get_edge_index(self, time_index)
     77             return self.edge_indices[time_index]
     78         else:
---> 79             return torch.LongTensor(self.edge_indices[time_index])
     80 
     81     def _get_batch_index(self, time_index: int):

ValueError: expected sequence of length 6 at dim 2 (got 4)
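The numbers in the message narrow things down. A tensor can only be built from a rectangular nested list. "dim 2" means the lengths disagree two levels inside `edge_indices[time_index]`, so that single entry is nested at least three levels deep instead of being one `2 x E` block. The dimensions aren't shown above, so the structure below is only a guess that reproduces the same numbers: one timestamp holding two graphs with 6 and 4 edges. This pure-Python sketch imitates the size check to locate the first ragged level. `infer_shape` and `find_mismatch` are hypothetical helpers, not part of PyTorch or PyTorch Geometric Temporal.

```python
def infer_shape(seq):
    """Guess a shape by following the first element at every level,
    the way tensor constructors size a nested Python list."""
    shape = []
    while isinstance(seq, (list, tuple)):
        shape.append(len(seq))
        if not seq:
            break
        seq = seq[0]
    return shape


def find_mismatch(seq, shape, dim=0):
    """Return (dim, expected_length, got_length) for the first sub-list whose
    length differs from `shape`, or None if the nesting is rectangular."""
    if dim == len(shape):
        return None  # reached the scalar level
    if not isinstance(seq, (list, tuple)):
        return dim, shape[dim], None  # a scalar where a list was expected
    if len(seq) != shape[dim]:
        return dim, shape[dim], len(seq)
    for child in seq:
        found = find_mismatch(child, shape, dim + 1)
        if found is not None:
            return found
    return None


# Hypothetical edge_indices[time_index]: two graphs with 6 and 4 edges.
snapshot = [
    [[0, 1, 2, 3, 4, 5], [1, 2, 3, 4, 5, 0]],
    [[0, 1, 2, 3], [1, 2, 3, 0]],
]
print(infer_shape(snapshot))                           # [2, 2, 6]
print(find_mismatch(snapshot, infer_shape(snapshot)))  # (2, 6, 4)
```

Running the same check on each real `edge_indices[t]` shows which timestamps are ragged and at which nesting level.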

Any indication of what I'm doing wrong?

benedekrozemberczki commented 1 year ago

Based on this, I cannot tell.

LaurentBerder commented 1 year ago

Thanks @benedekrozemberczki.

That's a bummer. But the DynamicGraphTemporalSignalBatch class is designed for this kind of structure, right?

What extra details could I provide to give you a better idea of the problem?

Or maybe you could tell me what shape DynamicGraphTemporalSignalBatch expects for the edge_indices of a graph whose number of edges changes over time (that's how I interpret the "dynamic" part of the graph).
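Suppose each timestamp actually holds several separate graphs. The `batches` argument suggests this, and it would explain a third nesting level. One possible answer, assumed here rather than taken from the library documentation, is to merge the graphs the way PyTorch Geometric batches them:

- concatenate the per-graph `(2, E_g)` edge indices along the edge axis, after shifting each graph's node ids by the number of nodes that precede it;
- record which graph each node belongs to in a batch vector.

A NumPy sketch (`collate_snapshot` is a hypothetical helper):

```python
import numpy as np


def collate_snapshot(edge_index_list, num_nodes_list):
    """Merge the graphs observed at one timestamp into one disjoint graph.

    edge_index_list: one (2, E_g) integer array per graph, with node ids
                     local to that graph (0 .. N_g - 1).
    num_nodes_list:  N_g for each graph, in the same order.
    Returns (edge_index, batch): a (2, sum E_g) array whose node ids are
    shifted so graphs do not collide, and a (sum N_g,) array giving the
    graph id of every node.
    """
    num_nodes_list = list(num_nodes_list)
    # Offset of graph g = total number of nodes in graphs 0 .. g-1.
    offsets = np.cumsum([0] + num_nodes_list[:-1])
    edge_index = np.concatenate(
        [np.asarray(ei, dtype=np.int64) + off
         for ei, off in zip(edge_index_list, offsets)],
        axis=1,
    )
    # Every node of graph g is labelled g in the batch vector.
    batch = np.repeat(np.arange(len(num_nodes_list)), num_nodes_list)
    return edge_index, batch


# Hypothetical timestamp with two ring graphs: 6 nodes/6 edges and 4 nodes/4 edges.
g0 = [[0, 1, 2, 3, 4, 5], [1, 2, 3, 4, 5, 0]]
g1 = [[0, 1, 2, 3], [1, 2, 3, 0]]
edge_index, batch = collate_snapshot([g0, g1], [6, 4])
print(edge_index.shape)   # (2, 10)
print(edge_index[:, 6:])  # second graph's edges, node ids shifted by 6
print(batch)              # [0 0 0 0 0 0 1 1 1 1]
```

Under that assumption, the returned `edge_index` and `batch` for each timestamp would become `edge_indices[t]` and `batches[t]`. `features[t]`, node-level `targets[t]` and `edge_weights[t]` would be concatenated in the same graph order, so every entry is again a regular array.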