graphdeeplearning / graphtransformer

Graph Transformer Architecture. Source code for "A Generalization of Transformer Networks to Graphs", DLG-AAAI'21.
https://arxiv.org/abs/2012.09699
MIT License

Error in Using this Graph Transformer Layer on random graph #27

Closed: Akash82228 closed this issue 8 months ago.

Akash82228 commented 8 months ago

Hi! I am trying to use this Graph Transformer layer on a random graph (see below), but it fails with `KeyError: 'wV'`.

```python
import torch
import dgl
import networkx as nx
from model import GraphTransformerLayer  # this imports layers/graph_transformer_layer.py

torch.manual_seed(42)

def create_random_graph(num_nodes, node_feature_dim):
    g = dgl.DGLGraph()
    g.add_nodes(num_nodes)  # adds nodes only; no edges are created
    node_features = torch.randn(num_nodes, node_feature_dim)
    g.ndata['feat'] = node_features
    return g

num_nodes = 10
node_feature_dim = 16
num_heads = 4

random_graph = create_random_graph(num_nodes, node_feature_dim)
model_layer = GraphTransformerLayer(in_dim=node_feature_dim, out_dim=node_feature_dim, num_heads=num_heads)
output_features = model_layer(random_graph, random_graph.ndata['feat'])
```
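A possible cause, offered as an assumption rather than a confirmed diagnosis: `create_random_graph` adds 10 nodes but no edges. The layer's attention presumably fills `g.ndata['wV']` (plus a normalizer, called `z` here) by message passing along edges, and DGL's `send_and_recv` appears to return without writing anything when given an empty edge list, so the later lookup of `'wV'` fails. The pure-Python sketch below mimics that control flow without using DGL; `toy_send_and_recv` and `toy_attention` are hypothetical helpers, not DGL or repo APIs, and the features are made up.

```python
import math


def toy_send_and_recv(ndata, edges, num_nodes):
    """Hypothetical stand-in for the edge-wise attention step.

    For each edge (u, v), node v accumulates exp(score) * V[u] into 'wV'
    and exp(score) into 'z', where score = dot(K[u], Q[v]) / sqrt(dim).
    Assumption: like DGL's send_and_recv, it returns early on an empty
    edge list, so 'wV' and 'z' are never created.
    """
    if len(edges) == 0:
        return  # no edges: nothing is written to ndata
    dim = len(ndata["V"][0])
    # Nodes that receive no messages keep zeros, like a sum reducer.
    wV = [[0.0] * dim for _ in range(num_nodes)]
    z = [0.0] * num_nodes
    for u, v in edges:
        score = sum(k * q for k, q in zip(ndata["K"][u], ndata["Q"][v])) / math.sqrt(dim)
        weight = math.exp(score)
        for i in range(dim):
            wV[v][i] += weight * ndata["V"][u][i]
        z[v] += weight
    ndata["wV"] = wV
    ndata["z"] = z


def toy_attention(ndata, edges, num_nodes):
    """Runs the toy message passing, then reads the aggregated values back."""
    toy_send_and_recv(ndata, edges, num_nodes)
    # Raises KeyError: 'wV' if the step above wrote nothing (empty edge list);
    # a node with no incoming edges would divide by z == 0.
    return [[x / ndata["z"][n] for x in ndata["wV"][n]] for n in range(num_nodes)]


feats = [[1.0, 0.0], [0.0, 1.0], [1.0, 1.0]]
ndata = {"Q": feats, "K": feats, "V": feats}

try:
    toy_attention(dict(ndata), [], num_nodes=3)  # edgeless graph, as in the issue
except KeyError as err:
    print("KeyError:", err)  # KeyError: 'wV'

# With one self-loop per node, each node attends only to itself.
self_loops = [(n, n) for n in range(3)]
print(toy_attention(dict(ndata), self_loops, num_nodes=3))  # [[1.0, 0.0], [0.0, 1.0], [1.0, 1.0]]
```

If this is the cause, giving the graph edges before calling the layer, for example `random_graph = dgl.add_self_loop(random_graph)` or building it with `dgl.rand_graph(num_nodes, num_edges)` for a chosen `num_edges`, should let `wV` be populated. Self-loops have the added benefit that every node receives at least one message, which avoids dividing by a zero normalizer for nodes with no incoming edges.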

Akash82228 commented 8 months ago
[Screenshot 2024-03-06 at 19:27:12]