benedekrozemberczki / pytorch_geometric_temporal

PyTorch Geometric Temporal: Spatiotemporal Signal Processing with Neural Machine Learning Models (CIKM 2021)
MIT License

Putting multiple graphs in one StaticGraphTemporalSignal iterator #84

Closed omarsinnno closed 3 years ago

omarsinnno commented 3 years ago

I'm dealing with a basketball player trajectory forecasting dataset; it is both temporal and spatial, with the shape [number of plays, number of timesteps, player, position]. Each play is composed of several timesteps, so in total I would have number of plays × number of timesteps graphs.

From reading the documentation and source code it seems that StaticGraphTemporalSignal accepts multiple graphs, but I'm not sure how to do that. Moreover, I cannot create an instance of this iterator without passing target labels, which is not something I need in this project, unless I use the features of the next step as the target labels for the prediction.
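Something like the following is what I have in mind (a minimal sketch; the shapes and values are placeholders, and I'm assuming the constructor takes edge_index, edge_weight, features, targets):

import numpy as np
from torch_geometric_temporal.signal import StaticGraphTemporalSignal

# Placeholder shapes: T timesteps, N players, F features per player.
T, N, F = 100, 11, 2
edge_index = np.array([[0, 1], [1, 0]])     # (2, num_edges), toy graph
edge_weight = np.ones(edge_index.shape[1])  # (num_edges,)
node_features = np.random.rand(T, N, F)

# Use the features of the next step as the targets.
features = [node_features[t] for t in range(T - 1)]
targets = [node_features[t + 1] for t in range(T - 1)]

dataset = StaticGraphTemporalSignal(edge_index, edge_weight, features, targets)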

I'm not sure if my question is clear; I would be glad to elaborate. I'd appreciate any help :)

benedekrozemberczki commented 3 years ago

Hi @omarsinnno,

Thank you for reaching out! In order to do this, you would need to use the diagonal batching trick: create a block-diagonal adjacency matrix for each time step. This can be done with the batch iterator variants, which require batch membership indicators.
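A minimal sketch of the trick (the helper and variable names below are illustrative, not part of the library): stack the per-play graphs of one time step into a single large graph by offsetting node indices, and keep a batch vector recording which play each node belongs to.

import numpy as np

def block_diagonal_step(edge_indices, features_per_play):
    """Merge the graphs of all plays at one time step into a single
    block-diagonal graph. `edge_indices` is a list of (2, E_i) arrays,
    `features_per_play` a list of (N_i, F) arrays. Illustrative helper."""
    edges, feats, batch = [], [], []
    offset = 0
    for play_id, (ei, x) in enumerate(zip(edge_indices, features_per_play)):
        edges.append(ei + offset)                    # shift node ids per play
        feats.append(x)
        batch.append(np.full(x.shape[0], play_id))   # batch membership indicator
        offset += x.shape[0]
    return (np.concatenate(edges, axis=1),
            np.concatenate(feats, axis=0),
            np.concatenate(batch, axis=0))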

If you need help we can set up a short call where we deal with this.

Once again, thank you for reaching out. Feel free to star the repo and hit follow!

Bests,

Benedek

omarsinnno commented 3 years ago

Hello @benedekrozemberczki

Thank you for your timely reply! I'm not quite sure how to apply the diagonal batching trick here; if you could provide an example, that would be perfect. A call would be nice if you're available during the weekend, since I'm new to this!

benedekrozemberczki commented 3 years ago

https://github.com/tkipf/gcn/issues/4

benedekrozemberczki commented 3 years ago

Dear @omarsinnno,

Do you still need help with this? Do you want to write a custom data loader and share the dataset?

Bests,

Benedek

omarsinnno commented 3 years ago

Dear @benedekrozemberczki

I will upload one as soon as possible. I think I figured out the diagonalization trick; what remains is figuring out how to use the batch static graph temporal signal iterator so that the data can be loaded into memory.

omarsinnno commented 3 years ago

Here is a data loader I made while reading the code of the chickenpox dataloader:

import json
import os.path as osp

import numpy as np
import torch
from tqdm import tqdm
from torch_geometric_temporal.signal import StaticGraphTemporalSignal


class BskDatasetLoader(object):
    def __init__(self, root):
        super(BskDatasetLoader, self).__init__()
        rp = osp.join(root, 'bsk_data.json')
        with open(rp, 'r') as rpf:
            self.data = json.load(rpf)

        if not osp.isfile(osp.join(root, 'restruct_json.json')):
            self.restruct_json = self.restructure_json()
            with open(osp.join(root, 'restruct_json.json'), 'w') as dj:
                json.dump(self.restruct_json, dj)
        else:
            print('Loading restructured json...')
            with open(osp.join(root, 'restruct_json.json'), 'r') as dj:
                self.restruct_json = json.load(dj)
            print('Restructured json loaded.')

    def restructure_json(self):
        # Flatten the nested {play: {timestamp: {edges, FX}}} structure into
        # two lists indexed by [play][timestamp].
        dic = dict()
        edges_list = []
        fx_list = []
        for g, graph in tqdm(enumerate(self.data), desc='Restructuring JSON:', total=len(self.data.keys())):
            edges_list.append([])
            fx_list.append([])
            for t, ts in enumerate(self.data[graph]):
                edges = self.data[graph][ts]['edges']
                fx = self.data[graph][ts]['FX']
                edges_list[g].append(edges)
                fx_list[g].append(fx)

        dic['edges'] = edges_list
        dic['FX'] = fx_list

        return dic

    def _get_edges(self):
        self._edges = np.array(self.restruct_json['edges'])
        print('Got edges: --> shape {s}'.format(s=self._edges.shape))

    def _get_edge_weights(self):
        # Every timestamp of every play gets unit weights for the
        # 11 * (11 - 1) directed player-to-player edges.
        self.container = dict()
        weights_list = []
        for g, graph in tqdm(enumerate(self.data), desc='Generating edge weights', total=len(self.data.keys())):
            weights_list.append([])
            for t, ts in enumerate(self.data[graph]):
                num_edges = 11 * (11 - 1)
                edge_weights = torch.ones((num_edges, 1), dtype=torch.long).tolist()
                weights_list[g].append(edge_weights)

        self.container['edge_weights'] = weights_list
        self._edge_weights = np.array(self.container['edge_weights'])
        print('Got edge weights: --> shape {s}'.format(s=self._edge_weights.shape))

    def _get_targets_and_features(self):
        # Use the first `lags` timesteps of each play as features and the
        # remaining timesteps as targets.
        stacked_target = np.array(self.restruct_json['FX'])
        self.features = [stacked_target[i][:self.lags].T for i in range(stacked_target.shape[0])]
        self.targets = [stacked_target[i][self.lags:].T for i in range(stacked_target.shape[0])]

        self.features = np.array(self.features)
        self.targets = np.array(self.targets)
        print('Got features: --> shape {s}'.format(s=self.features.shape))
        print('Got targets: --> shape {s}'.format(s=self.targets.shape))

    def get_dataset(self, lags: int = 33) -> StaticGraphTemporalSignal:
        self.lags = lags
        self._get_edges()
        self._get_edge_weights()
        self._get_targets_and_features()
        dataset = StaticGraphTemporalSignal(self._edges, self._edge_weights, self.features, self.targets)
        return dataset

My apologies if the code is messy. bsk_data.json is a JSON file with the following structure:

play 0:
    timestamp 0:
        edges: list of edges
        FX: features
    timestamp 1:
        edges
        FX
    timestamp 2: ...
play 1:
    timestamp 0: ...

The rest of the methods are the same as in the chickenpox dataloader; _get_edge_weights is modified so that it creates one set of edge weights per timestamp and stores them in the dictionary self.container.
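For completeness, a minimal usage sketch of the loader above; the data directory and lag value are placeholders.

loader = BskDatasetLoader(root='data/')
dataset = loader.get_dataset(lags=33)

# StaticGraphTemporalSignal iterates over temporal snapshots
# (torch_geometric.data.Data objects).
for snapshot in dataset:
    print(snapshot.x.shape, snapshot.edge_index.shape, snapshot.y.shape)
    break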

  1. I'm not sure how the diagonalization trick can be inserted into a chickenpox-like dataloader.
  2. I guess I have to use a batch signal iterator, because the data in question is large and the device I'm training on is reaching memory capacity.

benedekrozemberczki commented 3 years ago

Each temporal snapshot is a graph batch (e.g. like molecules).
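
A minimal sketch of this idea with the batch iterator variant, assuming the StaticGraphTemporalSignalBatch constructor takes (edge_index, edge_weight, features, targets, batches); all names and shapes below are illustrative:

import numpy as np
from torch_geometric_temporal.signal import StaticGraphTemporalSignalBatch

# Illustrative sketch: P plays of N players each are merged into one
# block-diagonal graph per time step; shapes and values are assumptions.
P, N, F, T = 4, 11, 2, 50

# Fully connected directed edges (no self loops) within each play,
# offset so the plays do not share node ids (block-diagonal structure).
single = np.array([[i, j] for i in range(N) for j in range(N) if i != j]).T
edge_index = np.concatenate([single + p * N for p in range(P)], axis=1)
edge_weight = np.ones(edge_index.shape[1])

# Batch membership indicator: which play ("molecule") each node belongs to.
batches = np.concatenate([np.full(N, p) for p in range(P)])

# One large graph per time step, next-step features used as targets.
node_features = np.random.rand(T, P * N, F)
features = [node_features[t] for t in range(T - 1)]
targets = [node_features[t + 1] for t in range(T - 1)]

dataset = StaticGraphTemporalSignalBatch(edge_index, edge_weight,
                                         features, targets, batches)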