pathpy / pathpyG

GPU-accelerated Next-Generation Network Analytics and Graph Learning for Time Series Data on Complex Networks.
https://www.pathpy.net
GNU Affero General Public License v3.0

`TemporalGraph.get_window(...)` and `TemporalGraph.get_snapshot` behaviour not as expected #160

Open M-Lampert opened 5 months ago

M-Lampert commented 5 months ago

As I understand the two methods, get_snapshot should return a TemporalGraph that contains all edges from the original graph that occur in the time window from start to end. This means that the returned edge index can have very different sizes depending on the window. See the following example for the expected behaviour:

>>> t = TemporalGraph(TemporalData(
...     src=[0,0,1,0],
...     dst=[1,2,2,1],
...     t=[1,1,1,2]))
>>> t.get_snapshot(start=1, end=2)
TemporalGraph containing the first 3 edges

>>> t.get_snapshot(start=2, end=3)
TemporalGraph containing only the last edge

get_window(...), in contrast, should return a window of fixed size end - start that contains the edges from index start up to (but not including) index end of an edge index sorted by time. For example, get_window(0, 5) would return the first 5 events. The current implementation of get_window does exactly this, but get_snapshot(...) currently does the same thing instead of selecting by time.
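To make the intended difference concrete, here is a minimal sketch on plain Python lists. The helper names `snapshot_by_time` and `window_by_index` are hypothetical and not part of pathpyG, and the half-open interval `start <= t < end` is an assumption taken from the example above, where `get_snapshot(start=1, end=2)` excludes the edge at t=2.

```python
# Hypothetical helpers for illustration; edges are (src, dst, t) tuples sorted by t.
edges = [(0, 1, 1), (0, 2, 1), (1, 2, 1), (0, 1, 2)]


def snapshot_by_time(edges, start, end):
    """Select by timestamp: keep every edge with start <= t < end (assumed half-open)."""
    return [e for e in edges if start <= e[2] < end]


def window_by_index(edges, start, end):
    """Select by position: keep the edges at indices start..end-1 of the sorted list."""
    return edges[start:end]


print(snapshot_by_time(edges, 1, 2))  # the three edges with t == 1
print(snapshot_by_time(edges, 2, 3))  # only the last edge
print(window_by_index(edges, 0, 2))   # always end - start edges (if available)
```

The snapshot size depends on how many events fall into the time range, whereas the window size is fixed by the two indices.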

ALSO: For get_window to work consistently, we have to make sure that the edge_index stays sorted by time. This is not the case after applying shuffle_time in its current implementation.
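One possible way to keep the invariant is to re-sort right after shuffling. The sketch below works on plain tuples and assumes that shuffling means permuting the timestamps across edges; that may differ from what shuffle_time actually does in pathpyG, and the function name `shuffle_time_sorted` is hypothetical.

```python
import random


def shuffle_time_sorted(edges, seed=0):
    """Permute timestamps across edges, then restore the sorted-by-time order."""
    rng = random.Random(seed)
    times = [t for _, _, t in edges]
    rng.shuffle(times)
    # Reattach the shuffled timestamps to the original (src, dst) pairs.
    shuffled = [(s, d, t) for (s, d, _), t in zip(edges, times)]
    # sorted() is stable, so edges sharing a timestamp keep their relative order.
    return sorted(shuffled, key=lambda e: e[2])


edges = [(0, 1, 1), (0, 2, 1), (1, 2, 1), (0, 1, 2)]
result = shuffle_time_sorted(edges)
print(result)
```

With the sort applied as the last step, index-based slicing as in get_window again refers to consecutive events in time.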

M-Lampert commented 5 months ago

I already wrote a code snippet that efficiently (in terms of runtime) indexes all edges in a specific time window (i.e. for get_snapshot as I understand it):

import torch
from torch_geometric.utils import cumsum  # unlike torch.cumsum, prepends a leading 0

# Shift the timestamps so that they start at 0
data_min_t = data.t.min().item()
data.t = data.t - data_min_t
unique_t, t_counts = data.t.unique(return_counts=True)
# For each consecutive pair of timestamps, we need to know how many timestamps are missing in between
missing_steps = unique_t[1:] - unique_t[:-1]
one = torch.ones(1, dtype=missing_steps.dtype, device=data.t.device)
# Create a pointer that you can index with each timestamp and points to the position in the edge_index where this specific timestamp starts
data.ptr = torch.repeat_interleave(
    cumsum(t_counts),
    torch.cat([one, missing_steps, one]),
)

Here I assume that data is a TemporalData object from PyG and that the edge index is sorted by time. I also assume the timestamps are integers and remap them to start at 0 for simplicity. We could also start at any other timestamp, but this could waste a lot of memory since we would store a ptr entry for potentially many timestamps that never occur.

With the code above, get_snapshot could look as follows:

def get_snapshot(start, end):
    return data[data.ptr[start]:data.ptr[end]]

We could also (if we do not want to reindex but still start at the minimum) do this:

def get_snapshot(start, end):
    return data[data.ptr[start-data_min_t]:data.ptr[end-data_min_t]]
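To see what the pointer contains, here is a plain-Python sketch of the same idea on a tiny input. It is an illustrative re-implementation with hypothetical names (`build_ptr`, `snapshot_with_ptr`), not the tensor code above, and it assumes integer timestamps that are already sorted.

```python
def build_ptr(sorted_times):
    """Return (t_min, ptr) where ptr[t - t_min] is the index of the first event with time >= t."""
    t_min, t_max = sorted_times[0], sorted_times[-1]
    ptr = []
    pos = 0
    # One entry per integer timestamp in [t_min, t_max], plus one closing entry.
    for t in range(t_min, t_max + 2):
        while pos < len(sorted_times) and sorted_times[pos] < t:
            pos += 1
        ptr.append(pos)
    return t_min, ptr


def snapshot_with_ptr(events, t_min, ptr, start, end):
    """Slice out all events with start <= t < end using two pointer lookups."""
    return events[ptr[start - t_min]:ptr[end - t_min]]


times = [5, 5, 8]
events = ["a", "b", "c"]
t_min, ptr = build_ptr(times)
print(ptr)  # [0, 2, 2, 2, 3]: timestamps 6 and 7 never occur but still get an entry
print(snapshot_with_ptr(events, t_min, ptr, 5, 6))  # ['a', 'b']
```

The repeated entries for the missing timestamps 6 and 7 show the memory cost mentioned above: the pointer grows with the time span, not with the number of events.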

By trading off runtime for memory efficiency, we could also search for the correct pointers in the sorted timestamp tensor t. If a snapshot is only taken once, this is probably preferable, but if we want to use this as an iterator with a rolling time window, the first approach could save a lot of time for large datasets.
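A minimal sketch of this search-based alternative, using binary search from the Python standard library on a sorted list of timestamps. The helper name `snapshot_bounds` is hypothetical; for tensors, `torch.searchsorted` would be the analogous call.

```python
from bisect import bisect_left


def snapshot_bounds(sorted_times, start, end):
    """Return (lo, hi) so that sorted_times[lo:hi] holds exactly the times with start <= t < end."""
    lo = bisect_left(sorted_times, start)  # first position with time >= start
    hi = bisect_left(sorted_times, end)    # first position with time >= end
    return lo, hi


times = [1, 1, 1, 2]
edges = [(0, 1), (0, 2), (1, 2), (0, 1)]
lo, hi = snapshot_bounds(times, 1, 2)
print(edges[lo:hi])  # the three edges with t == 1
```

Each lookup costs O(log n) and needs no extra memory, whereas the precomputed pointer answers in O(1) after a one-time setup.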

M-Lampert commented 5 months ago

So there is a PyG method TemporalData.snapshot(...) which should work the same way as our intended get_snapshot(...), but it did not work for me when I tried it. I have now found the reason (https://github.com/pyg-team/pytorch_geometric/issues/3230): snapshot(...) is implemented in Data and not in TemporalData. TemporalData will not be supported for much longer and will be deprecated in the future. This is also why sort_by_time and other time-related methods did not work for me before when tested with TemporalData: TemporalData just inherited them from the implementation in Data. Long story short, we can do something like this:

def get_snapshot(start, end):
    return data.snapshot(start, end)  # data needs to be a PyG Data object, not a TemporalData object!