M-Lampert opened 5 months ago
I already wrote a code snippet that efficiently (in terms of runtime) indexes all edges in a specific time window (i.e., what get_snapshot should do, as I understand it):
    import torch
    from torch_geometric.utils import cumsum  # prepends a leading zero to the cumulative sum

    # Remap timestamps to start at 0.
    data_min_t = data.t.min().item()
    data.t = data.t - data_min_t
    unique_t, t_counts = data.t.unique(return_counts=True)
    # For each consecutive pair of timestamps, count how many timestamps are missing in between.
    missing_steps = unique_t[1:] - unique_t[:-1]
    # Build a pointer tensor: ptr[s] is the position in the time-sorted edge index
    # where timestamp s starts.
    data.ptr = torch.repeat_interleave(
        cumsum(t_counts),
        torch.cat(
            [
                # torch.long so the dtype matches missing_steps in the concatenation
                torch.ones(1, dtype=torch.long, device=data.t.device),
                missing_steps,
                torch.ones(1, dtype=torch.long, device=data.t.device),
            ]
        ),
    )
Here I assume that data is a TemporalData object from PyG and that the edge index is sorted by time. I also assume the timestamps are integers and remap them to start at 0 for simplicity. We could also start at any timestamp, but that could waste a lot of memory, since we would store a ptr entry for potentially many timestamps that never occur.
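To illustrate the pointer layout, here is a small pure-Python sketch of the same logic (plain lists instead of tensors, and the timestamps are a made-up example):

```python
def build_ptr(t):
    """Mirror of the tensor code above, for a sorted list of integer
    timestamps that already start at 0: ptr[s] is the first position in
    the time-sorted event list whose timestamp is >= s."""
    # Unique timestamps with their counts, in sorted order.
    unique_t = sorted(set(t))
    counts = [t.count(u) for u in unique_t]
    # Cumulative counts with a leading zero (like PyG's cumsum).
    csum = [0]
    for c in counts:
        csum.append(csum[-1] + c)
    # Gaps between consecutive unique timestamps.
    missing = [b - a for a, b in zip(unique_t, unique_t[1:])]
    repeats = [1] + missing + [1]
    # repeat_interleave: repeat csum[i] exactly repeats[i] times.
    ptr = []
    for value, rep in zip(csum, repeats):
        ptr.extend([value] * rep)
    return ptr

t = [0, 0, 2, 2, 2, 3]
print(build_ptr(t))  # [0, 2, 2, 5, 6]
```

Note how timestamp 1 never occurs, so ptr[1] == ptr[2]: a snapshot starting at 1 correctly begins at the first event with timestamp 2.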
With the code above, get_snapshot could look as follows:

    def get_snapshot(start, end):
        return data[data.ptr[start]:data.ptr[end]]
We could also (if we do not want to remap the timestamps but still start at the minimum) do this:

    def get_snapshot(start, end):
        return data[data.ptr[start - data_min_t]:data.ptr[end - data_min_t]]
By trading off runtime for memory efficiency, we could instead search for the correct pointers in the sorted timestamp tensor t. If the lookup is only used once, this is probably preferable; but if we want to use it as an iterator with a rolling time window, the precomputed pointer approach could save a lot of time on large datasets.
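A minimal sketch of this search-based variant, using Python's bisect on a sorted timestamp list (torch.searchsorted would be the tensor equivalent); snapshot_bounds is a hypothetical helper name:

```python
from bisect import bisect_left

def snapshot_bounds(t_sorted, start, end):
    """Find the [lo, hi) index bounds of events with start <= t < end
    by binary search on the sorted timestamps: O(log n) per query,
    no precomputed ptr tensor stored."""
    lo = bisect_left(t_sorted, start)
    hi = bisect_left(t_sorted, end)
    return lo, hi

t = [0, 0, 2, 2, 2, 3]
print(snapshot_bounds(t, 0, 3))  # (0, 5)
print(snapshot_bounds(t, 1, 3))  # (2, 5)
```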
So there is a PyG method TemporalData.snapshot(...) that should work the same way as our intended get_snapshot(...), but it was not working for me when I tried it. Now I found the reason why: https://github.com/pyg-team/pytorch_geometric/issues/3230. I.e., snapshot(...) is implemented in Data, not in TemporalData. TemporalData will not be supported for much longer and will be deprecated in the future. This is also the reason why sort_by_time and other time-related methods did not work for me before when tested with TemporalData: TemporalData just inherited them from the implementation in Data.
Long story short: we can do something like this:

    def get_snapshot(start, end):
        # data needs to be a PyG Data object, not a TemporalData object!
        return data.snapshot(start, end)
As I understand both methods, get_snapshot should return a TemporalGraph that contains all edges from the original graph whose timestamps fall in the time window from start to end. This means the returned edge index can have very different sizes depending on the window. get_window(...) instead should return a window of fixed size end - start that contains the edges starting at index start and ending (non-inclusive) at index end in an edge index sorted by time. Thus, e.g., get_window(0, 5) would return the first 5 events. The current implementation of get_window does exactly this, but get_snapshot(...) also does this.

ALSO: For get_window to work consistently, we have to ensure that we keep an edge_index sorted by time. This is not the case after applying shuffle_time in its current implementation.
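To make the intended difference concrete, here is a toy comparison in plain Python over a made-up list of (src, dst, t) events sorted by time:

```python
events = [  # (src, dst, t), sorted by time
    (0, 1, 0), (1, 2, 0), (2, 3, 2), (3, 4, 2), (4, 5, 2), (5, 6, 3),
]

def get_snapshot(events, start, end):
    """All events whose timestamp lies in [start, end): result size varies
    with how many events fall into the time window."""
    return [e for e in events if start <= e[2] < end]

def get_window(events, start, end):
    """Fixed-size slice of the time-sorted event list: always end - start events."""
    return events[start:end]

print(len(get_snapshot(events, 0, 2)))  # 2 (only the two events with timestamp 0)
print(len(get_window(events, 0, 5)))    # 5 (always the first five events)
```

The same arguments (0 and 5, say) mean a time range for get_snapshot but index positions for get_window, which is why get_window silently breaks once the event list is no longer sorted by time.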