benedekrozemberczki / pytorch_geometric_temporal

PyTorch Geometric Temporal: Spatiotemporal Signal Processing with Neural Machine Learning Models (CIKM 2021)
MIT License
2.61k stars 367 forks source link

Building a dataset for the temporal graph library when the input graph has a different structure from the output graph #207

Closed neuronphysics closed 1 year ago

neuronphysics commented 1 year ago

Hello, I am new to this library. I have an input graph whose node features evolve over time, and an output graph with a different structure (a different number of nodes and edges than the input graph) in which only the features change over time. I wrote this dataset class:

import os

import networkx as nx
import numpy as np
import pandas as pd
import torch
from networkx import from_numpy_array, from_numpy_matrix
from torch_geometric.utils import dense_to_sparse
from torch_geometric_temporal.signal import StaticHeteroGraphTemporalSignal
from torch_geometric_temporal.signal import temporal_signal_split
def adjacency_matrix_transition(node_state_type_size, node_action_type_size):
    """Build the dense weighted adjacency matrix of a two-group graph.

    The first ``node_state_type_size`` nodes form the "state" group and the
    following ``node_action_type_size`` nodes form the "action" group. Every
    pair of distinct nodes is connected; the weight depends only on the
    groups of the two endpoints:

    * state  <-> state  : 1
    * state  <-> action : 0.6
    * action <-> action : 0.9

    Args:
        node_state_type_size: Number of state nodes.
        node_action_type_size: Number of action nodes.

    Returns:
        A symmetric ``float32`` array of shape ``(ns + na, ns + na)`` with a
        zero diagonal (no self-loops).
    """
    ns = node_state_type_size
    na = node_action_type_size
    A = np.zeros((ns + na, ns + na), dtype=np.float32)

    # Fill each group-pair block in one shot instead of visiting every
    # (i, j) pair in a Python loop.
    A[:ns, :ns] = 1    # state-state
    A[:ns, ns:] = 0.6  # state-action
    A[ns:, :ns] = 0.6  # action-state (mirror of the block above)
    A[ns:, ns:] = 0.9  # action-action

    # Block assignment also touched the diagonal; remove the self-loops.
    np.fill_diagonal(A, 0)
    return A

class GraphDataset(object):
    """Load saved input/target tensors and expose them as a temporal signal.

    Two graphs are built with ``adjacency_matrix_transition``: an input graph
    over ``obs_size + a_size`` nodes and a target graph over
    ``obs_size + r_size`` nodes. Sliding windows over the last axis of the
    loaded tensors become the per-snapshot feature and target dictionaries
    handed to ``StaticHeteroGraphTemporalSignal``.

    Args:
        input_name: File name of the saved input tensor.
        target_name: File name of the saved target tensor.
        obs_size: Number of "state" nodes shared by both graphs.
        a_size: Number of extra nodes in the input graph.
        r_size: Number of extra nodes in the target graph.
        filepath: Directory holding both files. ``None`` means the file
            names are used as given (relative to the working directory).
        max_length: Maximum number of entries kept along axis 1 of the
            loaded tensors.
        device: Stored on the instance; not otherwise used by this class.
    """

    def __init__(self,
                 input_name,
                 target_name,
                 obs_size,
                 a_size,
                 r_size,
                 filepath=None,
                 max_length=90,
                 device=torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
                 ):
        # The original called super(GymGraphDataset, self), which names a
        # class that does not exist and raised NameError on construction.
        super().__init__()
        self.a_size = a_size
        self.obs_size = obs_size
        self.r_size = r_size
        self.input_filename = input_name
        self.target_filename = target_name
        self.filepath = filepath
        self.max_length = max_length
        self.device = device
        self._read_data()

    def _read_data(self):
        """Load ``self.X`` / ``self.Y`` from disk and move time to the last axis."""
        # os.path.join(None, ...) raises TypeError, so treat a missing
        # directory as "use the file name as-is".
        base_dir = "" if self.filepath is None else self.filepath

        # NOTE: torch.load unpickles the file contents; only load files from
        # a trusted source.
        # Loaded tensors are indexed as 3-D and truncated along axis 1
        # (the original comment labels the layout batch, seq_length, input_size).
        self.X = torch.load(os.path.join(base_dir, self.input_filename))[:, :self.max_length, :]
        self.Y = torch.load(os.path.join(base_dir, self.target_filename))[:, :self.max_length, :]

        # Swap the last two axes; X additionally gets a singleton axis before
        # the last one (labelled (B, N, F, T) in the original comment).
        self.X = self.X.permute(0, 2, 1).unsqueeze(-2)
        self.Y = self.Y.permute(0, 2, 1)
        print(f"input data {self.X.shape}")
        print(f"target data {self.Y.shape}")

    @staticmethod
    def _normalized_laplacian(dense_adjacency):
        """Return ``I - D^{-1/2} A D^{-1/2}`` for a dense adjacency array."""
        # Round-trip through a NetworkX edge list, exactly as the original
        # did, so the resulting adjacency (node order, dtype) is unchanged.
        graph = from_numpy_array(dense_adjacency)
        edge_list = nx.to_pandas_edgelist(graph)
        rebuilt = nx.from_pandas_edgelist(edge_list, edge_attr=['weight'])
        adjacency = pd.DataFrame(nx.adjacency_matrix(rebuilt, weight='weight').todense())

        degree = np.diag(np.sum(adjacency, axis=1))
        identity = np.identity(adjacency.shape[0])
        # Inverting the degree matrix fails (LinAlgError) if any node has
        # zero weighted degree.
        degree_inv_sqrt = np.linalg.inv(np.sqrt(degree))
        return identity - np.dot(degree_inv_sqrt, adjacency).dot(degree_inv_sqrt)

    def _register_graph(self, key, graph):
        """Store sparse edges of ``graph`` under ``key``; return its Laplacian."""
        edge_indices, values = dense_to_sparse(graph)
        self.edge_index_dicts[key] = edge_indices.numpy()
        self.edge_weights_dicts[key] = values.numpy()
        return self._normalized_laplacian(graph.numpy())

    def _get_edges_and_weights(self):
        """Build both graphs, their sparse edge dicts and normalized Laplacians."""
        # NOTE(review): the dict keys are the plain strings "features" and
        # "targets"; confirm this matches the key format the signal class
        # expects for its edge dictionaries.
        self.edge_index_dicts = dict()
        self.edge_weights_dicts = dict()

        self.input_graph = torch.from_numpy(
            adjacency_matrix_transition(self.obs_size, self.a_size))
        self.normalized_laplacian_input = self._register_graph(
            "features", self.input_graph)

        self.target_graph = torch.from_numpy(
            adjacency_matrix_transition(self.obs_size, self.r_size))
        self.normalized_laplacian_output = self._register_graph(
            "targets", self.target_graph)

    def _get_targets_and_features(self, num_timesteps_in: int = 12, num_timesteps_out: int = 12):
        """Slice ``self.X`` / ``self.Y`` into sliding windows along the last axis.

        One window starts at every index ``i`` for which
        ``i + num_timesteps_in + num_timesteps_out`` still fits in the last
        axis of ``self.X``. Results are stored in ``self.features`` and
        ``self.targets`` as lists of single-key dicts of NumPy arrays.
        """
        window_count = self.X.shape[-1] - (num_timesteps_in + num_timesteps_out) + 1

        self.features, self.targets = [], []
        for start in range(window_count):
            stop = start + num_timesteps_in
            self.features.append({"features": self.X[:, :, :, start:stop].numpy()})
            # NOTE(review): the target window covers the same
            # [start, start + num_timesteps_in) span as the features;
            # num_timesteps_out only limits how many windows are produced.
            # Confirm whether a shifted window of length num_timesteps_out
            # was intended.
            self.targets.append({"targets": self.Y[:, :, start:stop].numpy()})

    def get_dataset(self, num_timesteps_in: int = 12, num_timesteps_out: int = 12) -> "StaticHeteroGraphTemporalSignal":
        """Returns data iterator as an instance of the static graph temporal signal class."""
        self._get_edges_and_weights()
        self._get_targets_and_features(num_timesteps_in, num_timesteps_out)
        dataset = StaticHeteroGraphTemporalSignal(
            self.edge_index_dicts,
            self.edge_weights_dicts,
            self.features,
            self.targets,
        )
        return dataset

def create_dataloader(train_dataset, DEVICE, bs, shuffle):
    """Wrap a dataset's features and targets in a batched ``DataLoader``.

    Args:
        train_dataset: Object exposing ``features`` and ``targets``
            attributes that ``np.array`` can stack into numeric arrays.
        DEVICE: Device the stacked tensors are moved to.
        bs: Batch size.
        shuffle: Whether the loader shuffles samples.

    Returns:
        A ``torch.utils.data.DataLoader`` over ``(features, targets)`` pairs
        of float32 tensors; a trailing incomplete batch is dropped.
    """
    # NOTE(review): entries must be plain numeric arrays. Dict-valued
    # entries would presumably stack into an object array that
    # torch.from_numpy cannot convert - verify against the caller.
    def _as_float_tensor(samples):
        stacked = np.array(samples)
        return torch.from_numpy(stacked).type(torch.FloatTensor).to(DEVICE)

    inputs = _as_float_tensor(train_dataset.features)   # (B, N, F, T) per original comment
    labels = _as_float_tensor(train_dataset.targets)    # (B, N, T) per original comment
    paired = torch.utils.data.TensorDataset(inputs, labels)
    return torch.utils.data.DataLoader(
        paired,
        batch_size=bs,
        shuffle=shuffle,
        drop_last=True,
    )

# Driver script: builds the dataset, splits it into train/valid/test,
# wraps each split in a DataLoader and extracts the static edge tensors.

# Runtime device: first CUDA GPU when available, otherwise the CPU.
DEVICE = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")  # cuda
print(DEVICE)
shuffle = False
# Node-group sizes forwarded to GraphDataset below.
obs_size, a_size, r_size= 11, 3, 1
batch_size = 16
# d_e / d_d are passed below as num_timesteps_in / num_timesteps_out.
d_e=7*2
d_d=7
# NOTE(review): `times` and `epoches` are not used in the visible code.
times = 20
epoches = 40

# Read data

input_name = "transition_inputs.pt"
target_name = "transition_outputs.pt"
filepath ="/home"
# Constructing GraphDataset immediately loads both tensors from `filepath`.
loader = GraphDataset(input_name, target_name, obs_size, a_size, r_size, filepath)
dataset = loader.get_dataset(num_timesteps_in=d_e, num_timesteps_out=d_d)
print(next(iter(dataset)))  # Show first sample
# Normalized Laplacian of the input graph, computed by get_dataset above.
A_tilde = torch.from_numpy(loader.normalized_laplacian_input).to(DEVICE)
# loader.features is a list of {"features": ndarray} dicts, one per window.
feat_df=pd.DataFrame.from_records(loader.features)

# Test Train Split
# NOTE(review): `t_ratio` is not defined anywhere in the visible code; this
# line raises NameError unless it is defined elsewhere.
train_dataset, test_dataset = temporal_signal_split(dataset, train_ratio=t_ratio)
# The held-out part is halved again into validation and test sets.
valid_dataset, test_dataset = temporal_signal_split(test_dataset, train_ratio=0.5)
print("Number of train data: ", len(list(train_dataset)))
print("Number of valid data: ", len(list(valid_dataset)))
print("Number of test data: ", len(list(test_dataset)))
# NOTE(review): this is the number of entries in feature_dicts (presumably
# one per snapshot), not a per-node feature count - confirm the intent.
n_feature = len(dataset.feature_dicts)  # Unstack
print(n_feature)
s_feature = n_feature * len(dataset.feature_dicts[0]["features"][0])  # Stack
print(s_feature)
# Length of the first feature dict, i.e. its number of keys.
n_node = len(dataset.feature_dicts[0])
print(n_node)
# NOTE(review): `X_s` is not defined anywhere in the visible code; this line
# raises NameError unless it is defined elsewhere.
n_static = X_s.shape[1]

# Create batches
# NOTE(review): create_dataloader reads `.features` / `.targets` from the
# split objects - confirm the signal class exposes those attribute names.
train_loader = create_dataloader(train_dataset, DEVICE, batch_size, shuffle)
valid_loader = create_dataloader(valid_dataset, DEVICE, batch_size, shuffle)
test_loader = create_dataloader(test_dataset, DEVICE, batch_size, shuffle)

# Static network, get edge features
# Only the first snapshot is inspected (the loop breaks immediately).
for snapshot in train_dataset:
    static_edge_index = snapshot.edge_index.to(DEVICE)
    static_edge_weight = snapshot.edge_attr.to(DEVICE)
    break

My questions:

Thanks.