pyg-team / pytorch_geometric

Graph Neural Network Library for PyTorch
https://pyg.org
MIT License

IndexError: Encountered an index error. Please ensure that all indices in 'edge_index' point to valid indices in the interval [0, X] (got interval [Y, Z]) #8630

Closed · tekeburak closed this issue 10 months ago

tekeburak commented 10 months ago

🐛 Describe the bug

I'm encountering an index error while attempting to train a GraphSAGE model using batch processing. When inspecting my batch data graph by graph (data[0].edge_index through data[n].edge_index), the maximum value in edge_index is never greater than num_nodes. However, when checking the data as a batch via data.edge_index, its maximum value exceeds num_nodes. This inconsistency seems to be causing issues during batch training.

The error arises at the line where I call out = model(data.x, data.edge_index). The model receives the edge indices as a batch, resulting in an index error, so I suspect a discrepancy in the edge indices when they are accessed as a batch.

The edge indices of the individual graphs never exceed 40, but when I access them as a batch, I see numbers greater than 40. This does not happen with y.

How can I overcome this index error? Help needed, @rusty1s.
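A compact version of the check described above (a sketch; train_dataset is the list of graphs and data the batched object from the snippets below):

for i, g in enumerate(train_dataset):
    # per graph: the maximum index should stay below g.num_nodes
    print(i, g.num_nodes, int(g.edge_index.max()))
# batched: here the maximum unexpectedly exceeds the node count
print(data.num_nodes, int(data.edge_index.max()))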

data[0].edge_index
tensor([[19, 31, 30, 28, 19, 19,  6, 12, 24, 28, 28, 28, 19, 19, 36,  6, 19, 18],
        [19, 24, 19, 15, 28, 28, 18,  6, 12, 15, 31, 15, 28,  6, 28, 19, 18, 30]])

data[1].edge_index
tensor([[31, 19, 19, 19, 34,  5,  5, 19, 19, 18, 36, 19,  5],
        [19, 18, 31, 19,  5, 15, 15, 34,  5,  5, 19, 18, 15]])

data[2].edge_index
tensor([[ 7, 28, 36],
        [15,  7, 28]])

data[15].edge_index
tensor([[ 5, 36],
        [15,  5]])

data[18].edge_index
tensor([[36, 28,  5],
        [28,  5, 15]])

data[25].edge_index
tensor([[ 5, 35, 31, 11, 36],
        [35, 15, 11,  5, 31]])

data.edge_index

tensor([[ 19,  31,  30,  28,  19,  19,   6,  12,  24,  28,  28,  28,  19,  19,
          36,   6,  19,  18,  41,  29,  29,  29,  44,  15,  15,  29,  29,  28,
          46,  29,  15,  24,  45,  53,  57,  32,  40,  39,  49,  34,  52,  40,
          34,  48,  48,  48,  60,  65,  48,  48,  40,  42,  48,  47,  47,  40,
          42,  60,  48,  48,  65,  62,  68,  56,  68,  49,  56,  56,  68,  56,
          73,  49,  62,  56,  68,  68,  56,  68,  78,  78,  63,  72,  80,  63,
          85,  57,  88,  63,  60,  74,  66,  91,  74,  86,  80,  92,  80,  72,
          97,  79,  91, 103,  75, 106,  92, 101,  78, 104, 101,  85, 101,  92,
         109,  92, 101,  78,  92,  98,  85,  99,  85,  99,  85,  98,  99,  85,
          85,  98, 116,  98,  99,  99,  85,  98,  90, 121,  93, 119, 116, 119,
         124,  98, 129, 124, 124, 104, 124, 124, 134, 126, 103, 107, 138, 110,
         141, 144, 119, 127, 139, 143, 127, 126, 136, 127, 127, 113, 142, 108,
         150, 147, 155, 151, 133, 133, 135, 151, 151, 127, 151, 133, 133, 147,
         133, 151, 128, 151, 159, 151, 133, 167, 162, 162, 139, 169, 165, 145,
         170, 150, 150, 168, 148, 168, 145, 145, 176, 147, 150, 168, 168, 159,
         148, 178, 148, 168, 151, 148, 159, 184, 184, 166, 169, 155, 178, 168,
         184, 178, 157, 181, 169, 168, 169, 169, 157, 169, 162, 169, 156, 177,
         169, 181, 156, 155, 186, 174, 169, 168, 169, 162, 169, 187, 169, 169,
         169, 184, 155, 169, 168, 169, 184, 168, 200, 201, 187, 192, 204, 187,
         203, 208, 200, 207, 189, 183, 212, 204, 181],
        [ 19,  24,  19,  15,  28,  28,  18,   6,  12,  15,  31,  15,  28,   6,
          28,  19,  18,  30,  29,  28,  41,  29,  15,  25,  25,  44,  15,  15,
          29,  28,  25,  32,  24,  45,  52,  36,  34,  32,  40,  39,  49,  39,
          44,  48,  42,  47,  48,  60,  48,  47,  47,  40,  42,  60,  34,  47,
          40,  48,  47,  47,  56,  52,  68,  68,  68,  68,  49,  49,  52,  68,
          65,  56,  52,  68,  62,  52,  56,  62,  59,  59,  78,  63,  72,  78,
          57,  64,  63,  67,  70,  66,  70,  86,  60,  74,  72,  80,  79,  79,
          92,  76,  82,  91,  85,  75,  78,  88, 101,  92,  88, 101,  88, 104,
          85, 101,  92, 101,  78,  98,  85,  98,  85,  98,  85,  85,  98,  98,
          98,  95,  99,  98,  98,  85,  85,  99, 100,  90, 103, 119,  93, 116,
         119, 108, 104,  98, 124, 124, 124, 124, 126, 103, 113, 117, 107, 120,
         110, 119, 127, 126, 108, 123, 136, 143, 113, 142, 126, 139, 123, 127,
         147, 134, 150, 151, 133, 133, 133, 151, 151, 151, 138, 133, 147, 128,
         133, 151, 127, 151, 135, 151, 133, 162, 146, 162, 169, 149, 145, 139,
         165, 150, 150, 147, 148, 168, 145, 148, 159, 168, 168, 151, 145, 150,
         148, 155, 168, 155, 168, 148, 178, 165, 165, 169, 174, 178, 168, 184,
         165, 157, 184, 177, 168, 166, 157, 184, 184, 156, 181, 168, 168, 187,
         184, 162, 168, 155, 181, 165, 155, 169, 156, 162, 169, 169, 169, 169,
         168, 165, 178, 168, 169, 168, 165, 155, 180, 200, 183, 183, 187, 192,
         187, 200, 203, 189, 204, 191, 207, 181, 183]])

Encountered an index error. Please ensure that all indices in 'edge_index' point to valid indices in the interval [0, 171] (got interval [9, 200]).

My code snippet is below.

import torch
from torch_geometric.data import Data, Batch
from torch_geometric.loader import DataLoader
from torch_geometric.nn import GraphSAGE

train_data = Batch.from_data_list(train_dataset)
loader = DataLoader(train_data, batch_size=32, shuffle=True)

device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
model = GraphSAGE(
    in_channels=9,
    hidden_channels=64,
    num_layers=2,
    out_channels=40,
).to(device)
optimizer = torch.optim.Adam(model.parameters(), lr=0.005)
criterion = torch.nn.NLLLoss()

def train():
    model.train()

    total_loss = total_examples = 0
    for data in loader:
        data = data.to(device)
        optimizer.zero_grad()
        out = model(data.x, data.edge_index)
        loss = criterion(out, data.y)
        loss.backward()
        optimizer.step()

        total_loss += loss.item()
        total_examples += 1

    return total_loss / total_examples if total_examples > 0 else 0.0

for epoch in range(1, 6):
    train_loss = train()
    print(f'Epoch: {epoch:02d}, Train Loss: {train_loss:.4f}')

Versions

Collecting environment information...
PyTorch version: 2.1.2
Is debug build: False
CUDA used to build PyTorch: None
ROCM used to build PyTorch: N/A

OS: macOS 14.1.1 (arm64)
GCC version: Could not collect
Clang version: 15.0.0 (clang-1500.0.40.1)
CMake version: Could not collect
Libc version: N/A

Python version: 3.8.18 (default, Sep 11 2023, 08:17:16) [Clang 14.0.6] (64-bit runtime)
Python platform: macOS-14.1.1-arm64-arm-64bit
Is CUDA available: False
CUDA runtime version: No CUDA
CUDA_MODULE_LOADING set to: N/A
GPU models and configuration: No CUDA
Nvidia driver version: No CUDA
cuDNN version: No CUDA
HIP runtime version: N/A
MIOpen runtime version: N/A
Is XNNPACK available: True

CPU: Apple M2

Versions of relevant libraries:
[pip3] numpy==1.22.4
[pip3] torch==2.1.2
[pip3] torch_geometric==2.4.0
[pip3] torch-scatter==2.1.2
[pip3] torchaudio==0.13.1
[pip3] torchvision==0.14.1
[conda] numpy 1.22.4 pypi_0 pypi
[conda] numpy-base 1.24.3 py38h90707a3_0
[conda] torch 2.1.2 pypi_0 pypi
[conda] torch-geometric 2.4.0 pypi_0 pypi
[conda] torch-scatter 2.1.2 pypi_0 pypi
[conda] torchaudio 0.13.1 py38_cpu pytorch
[conda] torchvision 0.14.1 py38_cpu pytorch

tekeburak commented 10 months ago

I tried another model that processes the batch differently. I got the same index error in the forward pass, at the line below.

x = self.conv1(x, edge_index)

When I checked, edge_index has values greater than 182, which leads to the error. But it should not have them.

batch.shape
torch.Size([181])

edge_index
tensor([[  5,  36,  31,  31,  11,  31,  31,  24,  29,  41,  24,  45,  40,  40,
          17,  31,  31,  31,  43,  48,  31,  31,  23,  25,  31,  30,  30,  23,
          25,  43,  31,  31,  39,  51,  39,  31,  56,  38,  31,  62,  65,  37,
          37,  68,  70,  71,  57,  66,  43,  69,  66,  50,  66,  57,  74,  57,
          66,  43,  57,  50,  76,  73,  76,  81,  81,  63,  57,  86,  78,  55,
          62,  93,  94,  94,  79,  88,  96,  79,  99,  99,  81,  84,  70,  93,
          83,  99,  93,  72,  96,  84,  83,  84,  84,  72,  84,  77,  84,  71,
          92,  84,  96,  71,  70, 101,  89,  84,  83,  84,  77,  84, 102,  84,
          84,  84,  99,  70,  84,  83,  84,  99,  83, 108,  90,  90,  92, 108,
         108,  84, 108,  90,  90, 104,  90, 108,  85, 108, 116, 108,  90,  93,
         107,  99, 124, 107, 119, 129, 130, 133, 125, 102, 132, 120, 120, 120,
         135, 106, 106, 120, 120, 119, 137, 120, 106, 139, 136, 144, 148, 123,
         131, 130, 140, 125, 143, 131, 125, 156, 146, 148, 161, 136, 144, 156,
         160, 144, 143, 153, 144, 144, 130, 159, 125, 141, 172, 157, 144, 158,
         144, 158, 144, 157, 158, 144, 144, 157, 175, 157, 158, 158, 144, 157,
         168, 180, 175, 172, 178, 166, 178, 159, 166, 166, 178, 166, 183, 159,
         172, 166, 178, 178, 166, 178, 164, 164, 182, 162, 182, 159, 159, 190,
         161, 164, 182, 182, 173, 162, 192, 162, 182, 165, 162, 173, 183, 195,
         194, 192, 183, 183, 170, 176, 188, 192, 192, 192, 183, 183, 200, 170,
         183, 182, 205, 210, 202, 214, 189],
        [ 15,  11,   5,  31,  31,  31,  31,  20,  20,  24,  29,  40,  24,  40,
          27,  31,  25,  30,  31,  43,  31,  30,  30,  23,  25,  43,  17,  30,
          23,  31,  30,  30,  31,  39,  38,  38,  51,  35,  41,  31,  37,  44,
          47,  37,  50,  70,  43,  53,  66,  57,  53,  66,  53,  69,  50,  66,
          57,  66,  43,  60,  76,  50,  73,  76,  63,  78,  65,  81,  55,  57,
          72,  62,  75,  75,  94,  79,  88,  94,  80,  80,  84,  89,  93,  83,
          99,  80,  72,  99,  92,  83,  81,  72,  99,  99,  71,  96,  83,  83,
         102,  99,  77,  83,  70,  96,  80,  70,  84,  71,  77,  84,  84,  84,
          84,  83,  80,  93,  83,  84,  83,  80,  70, 108,  90,  90,  90, 108,
         108, 108,  95,  90, 104,  85,  90, 108,  84, 108,  92, 108,  90, 103,
          99, 103, 119,  93, 107, 109, 129, 125, 102, 112, 120, 119, 132, 120,
         106, 116, 116, 135, 106, 106, 120, 119, 116, 136, 123, 139, 143, 127,
         125, 123, 131, 130, 140, 130, 146, 148, 135, 125, 136, 144, 143, 125,
         140, 153, 160, 130, 159, 143, 156, 140, 144, 151, 141, 157, 144, 157,
         144, 157, 144, 144, 157, 157, 157, 154, 158, 157, 157, 144, 144, 158,
         159, 168, 166, 162, 178, 178, 178, 178, 159, 159, 162, 178, 175, 166,
         162, 178, 172, 162, 166, 172, 164, 164, 161, 162, 182, 159, 162, 173,
         182, 182, 165, 159, 164, 162, 169, 182, 169, 182, 162, 192, 183, 188,
         183, 179, 192, 192, 182, 170, 176, 179, 195, 179, 192, 170, 192, 183,
         182, 194, 189, 202, 205, 189, 193]])
Exception has occurred: RuntimeError
index 182 is out of bounds for dimension 0 with size 181

import torch
from torch.nn import Linear
import torch.nn.functional as F
from torch_geometric.nn import GCNConv
from torch_geometric.nn import global_mean_pool

class GCN(torch.nn.Module):
    def __init__(self, hidden_channels):
        super(GCN, self).__init__()
        torch.manual_seed(12345)
        self.conv1 = GCNConv(9, hidden_channels)
        self.conv2 = GCNConv(hidden_channels, hidden_channels)
        self.conv3 = GCNConv(hidden_channels, hidden_channels)
        self.lin = Linear(hidden_channels, 40)

    def forward(self, x, edge_index, batch):
        # 1. Obtain node embeddings 
        x = self.conv1(x, edge_index)
        x = x.relu()
        x = self.conv2(x, edge_index)
        x = x.relu()
        x = self.conv3(x, edge_index)

        # 2. Readout layer
        x = global_mean_pool(x, batch)  # [batch_size, hidden_channels]

        # 3. Apply a final classifier
        x = F.dropout(x, p=0.5, training=self.training)
        x = self.lin(x)

        return x

model = GCN(hidden_channels=64)
optimizer = torch.optim.Adam(model.parameters(), lr=0.01)
criterion = torch.nn.CrossEntropyLoss()

def train():
    model.train()

    for data in loader:  # Iterate in batches over the training dataset.
         out = model(data.x, data.edge_index, data.batch)  # Perform a single forward pass.
         loss = criterion(out, data.y)  # Compute the loss.
         loss.backward()  # Derive gradients.
         optimizer.step()  # Update parameters based on gradients.
         optimizer.zero_grad()  # Clear gradients.

def test(loader):
     model.eval()

     correct = 0
     for data in loader:  # Iterate in batches over the training/test dataset.
         out = model(data.x, data.edge_index, data.batch)  
         pred = out.argmax(dim=1)  # Use the class with highest probability.
         correct += int((pred == data.y).sum())  # Check against ground-truth labels.
     return correct / len(loader.dataset)  # Derive ratio of correct predictions.

for epoch in range(1, 171):
    train()
    train_acc = test(train_loader)
    test_acc = test(test_loader)
    print(f'Epoch: {epoch:03d}, Train Acc: {train_acc:.4f}, Test Acc: {test_acc:.4f}')
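For what it's worth, the invariant that fails here can be checked right before the convolution (a sketch, using the x and edge_index passed to forward above):

# Every edge endpoint must address an existing row of x.
assert edge_index.max().item() < x.size(0), \
    'edge_index refers to a node row that does not exist'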
EdisonLeeeee commented 10 months ago

Hi @tekeburak

Looks like your train_dataset is a list of Data objects, so the correct way is to call

loader = DataLoader(train_dataset, batch_size=32, shuffle=True)
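For reference, a minimal sketch of that pattern with two toy graphs (shapes chosen to match this thread): DataLoader collates the list itself via Batch.from_data_list, offsetting each graph's edge_index by the running node count, so no pre-batching is needed:

import torch
from torch_geometric.data import Data
from torch_geometric.loader import DataLoader

# Two toy graphs with locally indexed edges (hypothetical data).
dataset = [
    Data(x=torch.randn(4, 9), edge_index=torch.tensor([[0, 1], [1, 2]]), y=torch.zeros(4, dtype=torch.long)),
    Data(x=torch.randn(3, 9), edge_index=torch.tensor([[0, 2], [1, 0]]), y=torch.zeros(3, dtype=torch.long)),
]

loader = DataLoader(dataset, batch_size=2, shuffle=True)
for batch in loader:
    # The second graph's indices are shifted by the first graph's node
    # count, so every index still points at a valid row of batch.x.
    print(batch.num_nodes, int(batch.edge_index.max()))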
tekeburak commented 10 months ago

Thank you for your answer @EdisonLeeeee. When I changed my code as below, I got the same error.

Exception has occurred: RuntimeError
index 182 is out of bounds for dimension 0 with size 181
# train_data = Batch.from_data_list(train_dataset)
# loader = DataLoader(train_data, batch_size=32, shuffle=True)

loader = DataLoader(train_dataset, batch_size=32, shuffle=True)
EdisonLeeeee commented 10 months ago

Can you also do me a favor by running:

for data in train_dataset:
    data.validate()

for batch in loader:
    batch.validate()
tekeburak commented 10 months ago

Sure! Here it is.

for data in train_dataset:
    data.validate()
data.validate()
ValueError: 'edge_index' contains larger indices than the number of nodes (10) in 'Data' (found 36)
data.edge_index
tensor([[19, 31, 30, 28, 19, 19,  6, 12, 24, 28, 28, 28, 19, 19, 36,  6, 19, 18],
        [19, 24, 19, 15, 28, 28, 18,  6, 12, 15, 31, 15, 28,  6, 28, 19, 18, 30]])
data.y
tensor([19, 31, 24, 30, 28, 15,  6, 18, 12, 36])

--

for batch in loader:
    batch.validate()
batch.validate()
ValueError: 'edge_index' contains larger indices than the number of nodes (183) in 'DataBatch' (found 213)
batch.edge_index
tensor([[  5,  31,  28,  31,  36,  36,  24,  24,  24,  39,  10,  10,  24,  24,
          23,  41,  24,  10,  31,  43,  31,  23,  48,  30,  28,  28,  46,  26,
          46,  23,  23,  54,  25,  28,  46,  46,  37,  26,  56,  26,  46,  29,
          26,  37,  33,  64,  50,  55,  67,  50,  40,  71,  56,  43,  57,  43,
          57,  43,  56,  57,  43,  43,  56,  74,  56,  57,  57,  43,  56,  79,
          74,  74,  77,  82,  74,  85,  86,  60,  81,  89,  88,  85,  93,  66,
          80,  72,  97,  80,  92, 103,  78,  86,  98, 102,  86,  85,  95,  86,
          86,  72, 101,  67, 106,  88,  88,  90, 106, 106,  82, 106,  88,  88,
         102,  88, 106,  83, 106, 114, 106,  88, 122,  97, 120, 102,  96, 125,
         117,  94, 101, 115, 115, 115, 127, 132, 115, 115, 107, 109, 115, 114,
         114, 107, 109, 127, 115, 115, 109, 140, 126, 138, 137, 135, 126, 126,
         113, 119, 131, 135, 135, 135, 126, 126, 143, 113, 126, 125, 153, 128,
         136, 135, 145, 130, 148, 136, 159, 159, 141, 144, 130, 153, 143, 159,
         153, 132, 156, 144, 143, 144, 144, 132, 144, 137, 144, 131, 152, 144,
         156, 131, 130, 161, 149, 144, 143, 144, 137, 144, 162, 144, 144, 144,
         159, 130, 144, 143, 144, 159, 143, 176, 168, 145, 168, 180, 159, 183,
         155, 186, 176, 178, 183, 180, 186, 174, 186, 167, 174, 174, 186, 174,
         191, 167, 180, 174, 186, 186, 174, 186, 167, 198, 193, 193, 173, 193,
         193, 202, 203, 189, 198, 175, 201, 198, 182, 198, 189, 206, 189, 198,
         175, 189, 182, 212, 208, 188, 213],
        [ 15,  31,   5,  28,  31,  24,  23,  36,  24,  10,  20,  20,  39,  10,
          10,  24,  23,  20,  23,  31,  30,  30,  43,  27,  28,  28,  25,  26,
          46,  23,  26,  37,  46,  46,  29,  23,  28,  26,  33,  46,  33,  46,
          26,  56,  43,  33,  46,  46,  50,  55,  50,  40,  56,  43,  56,  43,
          56,  43,  43,  56,  56,  56,  53,  57,  56,  56,  43,  43,  57,  74,
          58,  74,  61,  74,  77,  65,  85,  68,  60,  81,  85,  72,  88,  76,
          72,  76,  92,  66,  80,  78,  86,  85,  67,  82,  95, 102,  72, 101,
          85,  98,  82,  86, 106,  88,  88,  88, 106, 106, 106,  93,  88, 102,
          83,  88, 106,  82, 106,  90, 106,  88,  97, 101, 102, 117, 104, 120,
          94,  96, 111, 115, 109, 114, 115, 127, 115, 114, 114, 107, 109, 127,
         101, 114, 107, 115, 114, 114, 119, 109, 126, 131, 126, 122, 135, 135,
         125, 113, 119, 122, 138, 122, 135, 113, 135, 126, 125, 137, 148, 132,
         130, 128, 136, 135, 145, 135, 140, 140, 144, 149, 153, 143, 159, 140,
         132, 159, 152, 143, 141, 132, 159, 159, 131, 156, 143, 143, 162, 159,
         137, 143, 130, 156, 140, 130, 144, 131, 137, 144, 144, 144, 144, 143,
         140, 153, 143, 144, 143, 140, 130, 168, 145, 155, 159, 168, 162, 159,
         176, 178, 165, 155, 174, 170, 186, 186, 186, 186, 167, 167, 170, 186,
         183, 174, 170, 186, 180, 170, 174, 180, 177, 173, 167, 193, 193, 193,
         193, 182, 202, 175, 185, 198, 189, 185, 198, 185, 201, 182, 198, 189,
         198, 175, 212, 192, 188, 182, 208]])
batch.y
tensor([ 5, 15, 31, 28, 36, 31, 19, 18, 34,  5, 15, 36, 19, 11, 31, 18, 36, 15,
        10, 28,  7,  8,  5, 36, 19, 11, 38, 15,  5, 15, 36, 19, 15, 24, 36,  5,
        15, 36, 18,  5, 19, 15, 36, 36, 31, 15, 31, 15, 36, 28, 35, 15, 36,  7,
        15, 28, 36, 31, 28, 15, 36,  5, 15, 19, 11, 36, 31, 36, 11, 19, 18, 31,
         0, 35, 15, 28,  5, 34, 28, 10, 12,  4, 15, 24,  5, 36, 36, 11, 15, 31,
        13, 28,  7, 15, 36,  5,  5, 15, 19, 13, 18, 31, 36, 11,  5, 15, 36, 19,
        31, 24, 30, 28, 15,  6, 18, 12, 36, 36, 31, 11, 15, 19, 13, 18, 28, 34,
        15, 16, 19, 24,  5, 28, 18,  7, 31, 27,  6, 12, 37, 36, 36, 28,  5, 15,
        24, 15, 36, 12, 15, 36,  5, 26, 36, 28, 15, 28, 19, 25, 15, 31, 12, 36,
         5, 15, 36, 11, 31, 35, 15, 36, 19,  5, 28, 15, 31, 12, 36,  5, 35, 15,
        31, 11, 36])
EdisonLeeeee commented 10 months ago

So your original data, train_dataset, contains illegal edge indices, which leads to the error in the dataloader. How's your train_dataset organized?

tekeburak commented 10 months ago

My train dataset is a list of torch_geometric.data.Data objects; each list element is an independent graph. I have more than 500 independent graphs, and I want to do node classification at the graph level. All of the graphs draw from the same total set of 40 nodes [N0, N1, N2, ..., N39]. For example, the first graph has only 10 nodes, but with these node labels: [N19, N31, N24, N30, N28, N15, N6, N18, N12, N36]. These labels are what actually cause the index error. I'm stuck on how to handle the data.

Any suggestions would be highly appreciated.

train_dataset [Data(x=[10, 9], edge_index=[2, 18], y=[10], transform=NormalizeFeatures()), Data(x=[6, 9], edge_index=[2, 5], y=[6], transform=NormalizeFeatures()), Data(x=[5, 9], edge_index=[2, 5], y=[5], transform=NormalizeFeatures()), Data(x=[4, 9], edge_index=[2, 4], y=[4], transform=NormalizeFeatures()), Data(x=[3, 9], edge_index=[2, 2], y=[3], transform=NormalizeFeatures()), Data(x=[3, 9], edge_index=[2, 3], y=[3], transform=NormalizeFeatures()), Data(x=[3, 9], edge_index=[2, 2], y=[3], transform=NormalizeFeatures()), Data(x=[5, 9], edge_index=[2, 17], y=[5], transform=NormalizeFeatures()), Data(x=[7, 9], edge_index=[2, 13], y=[7], transform=NormalizeFeatures()), Data(x=[3, 9], edge_index=[2, 2], y=[3], transform=NormalizeFeatures()), Data(x=[5, 9], edge_index=[2, 6], y=[5], transform=NormalizeFeatures()), Data(x=[4, 9], edge_index=[2, 3], y=[4], transform=NormalizeFeatures()), Data(x=[10, 9], edge_index=[2, 20], y=[10], transform=NormalizeFeatures()), Data(x=[11, 9], edge_index=[2, 13], y=[11], transform=NormalizeFeatures()), Data(x=[4, 9], edge_index=[2, 3], y=[4], transform=NormalizeFeatures()), Data(x=[8, 9], edge_index=[2, 8], y=[8], transform=NormalizeFeatures()), Data(x=[8, 9], edge_index=[2, 18], y=[8], transform=NormalizeFeatures()), Data(x=[15, 9], edge_index=[2, 43], y=[15], transform=NormalizeFeatures()), Data(x=[7, 9], edge_index=[2, 18], y=[7], transform=NormalizeFeatures()), Data(x=[7, 9], edge_index=[2, 6], y=[7], transform=NormalizeFeatures()), Data(x=[6, 9], edge_index=[2, 6], y=[6], transform=NormalizeFeatures()), Data(x=[4, 9], edge_index=[2, 3], y=[4], transform=NormalizeFeatures()), Data(x=[3, 9], edge_index=[2, 2], y=[3], transform=NormalizeFeatures()), Data(x=[3, 9], edge_index=[2, 2], y=[3], transform=NormalizeFeatures()), Data(x=[5, 9], edge_index=[2, 7], y=[5], transform=NormalizeFeatures()), Data(x=[5, 9], edge_index=[2, 4], y=[5], transform=NormalizeFeatures()), Data(x=[3, 9], edge_index=[2, 2], y=[3], transform=NormalizeFeatures()), Data(x=[8, 9], edge_index=[2, 18], y=[8], transform=NormalizeFeatures()), Data(x=[7, 9], edge_index=[2, 13], y=[7], transform=NormalizeFeatures()), Data(x=[3, 9], edge_index=[2, 2], y=[3], transform=NormalizeFeatures()), Data(x=[3, 9], edge_index=[2, 2], y=[3], transform=NormalizeFeatures()), Data(x=[3, 9], edge_index=[2, 2], y=[3], transform=NormalizeFeatures()), Data(x=[3, 9], edge_index=[2, 2], y=[3], transform=NormalizeFeatures()), Data(x=[6, 9], edge_index=[2, 6], y=[6], transform=NormalizeFeatures()), Data(x=[4, 9], edge_index=[2, 3], y=[4], transform=NormalizeFeatures())]

rusty1s commented 10 months ago

If all your graphs have 40 nodes, then x and y should have a shape of [40, *]. If you want to mask out certain labels/features since you don't have information about them, you can just pass in an all zero vector for these nodes as node features.
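A minimal sketch of that padding, assuming 40 global node types and 9 features as in this thread (pad_graph and its arguments are hypothetical names):

import torch
from torch_geometric.data import Data

NUM_NODE_TYPES, NUM_FEATURES = 40, 9  # assumed sizes from this thread

def pad_graph(present_ids, feats, labels, edge_index):
    # present_ids: LongTensor of global node IDs appearing in this graph.
    # Allocate one row per global node type, so an edge_index that uses
    # global IDs (e.g. 36 in a 10-node graph) stays valid.
    x = torch.zeros(NUM_NODE_TYPES, NUM_FEATURES)
    y = torch.zeros(NUM_NODE_TYPES, dtype=torch.long)
    x[present_ids] = feats   # real features; absent nodes keep all-zero rows
    y[present_ids] = labels  # real labels; absent nodes keep label 0
    return Data(x=x, edge_index=edge_index, y=y, num_nodes=NUM_NODE_TYPES)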

tekeburak commented 10 months ago

If all your graphs have 40 nodes, then x and y should have a shape of [40, *]. If you want to mask out certain labels/features since you don't have information about them, you can just pass in an all zero vector for these nodes as node features.

Thank you @rusty1s for your answer. I changed my dataset as you suggested, and now I am able to train the model with batches.

But when I extract node features only for the nodes present in the current graph and leave the remaining nodes' features as zero, the model's training accuracy does not increase; it gets stuck at around 10%. I wonder whether I am missing something or doing something wrong.

My training code.

import pickle
import networkx as nx
import pandas as pd
from torch_geometric.data import Data
from torch_geometric.transforms import NormalizeFeatures
import numpy as np
from natsort import natsorted
import torch
import torch.nn.functional as F
from torch_geometric.loader import DataLoader
from sklearn.model_selection import train_test_split
from torch_geometric.nn import GraphSAGE
from torch_geometric.utils.convert import from_networkx, to_networkx

training_file = "example_file.parquet"
df = pd.read_parquet(training_file, engine='fastparquet')
nodes_list_all_graph = natsorted(df['originNodeType'].append(df['destinationNodeType']).unique().tolist())

do_calculation = False
num_features = 10  # number of node features; must be defined before the loop below uses it
individual_graphs_data = []

if do_calculation:
    for k,v in df.groupby('graphId', sort=False):
        all_node_features = []
        all_node_labels = []
        all_edges_index = []

        G = nx.DiGraph()  # Create a new graph for this 'graphId'
        edges = [(src,
                  dst)
                 for src, dst in zip(v["originNodeType"].tolist(), v["destinationNodeType"].tolist())]
        G.add_edges_from(edges)

        nx.set_node_attributes(G,{n:{'label':n} for n in G.nodes()})
        G = from_networkx(G)
        graph = to_networkx(G)

        degree_centrality = nx.degree_centrality(graph)
        closeness_centrality = nx.closeness_centrality(graph)
        betweenness_centrality = nx.betweenness_centrality(graph)
        pagerank_centrality = nx.pagerank(graph)

        try:
            eigenvector_centrality = nx.eigenvector_centrality(graph, max_iter=10000)
        except nx.PowerIterationFailedConvergence:
            print("Convergence failed. Adjust algorithm parameters or check graph structure.")
            eigenvector_centrality = {n: 0.0 for n in graph.nodes()}  # fallback so .get() below still works

        katz_centrality = nx.katz_centrality(graph)
        local_clustering_coefficient = nx.clustering(graph)
        node_degrees = dict(graph.degree())
        in_degrees = dict(graph.in_degree())
        out_degrees = dict(graph.out_degree())

        remaining_label = [item for item in nodes_list_all_graph if item not in G.label]
        for node_index, node_name in enumerate(G.label):
            features = [
                degree_centrality.get(node_index),
                closeness_centrality.get(node_index),
                betweenness_centrality.get(node_index),
                pagerank_centrality.get(node_index),
                eigenvector_centrality.get(node_index),
                katz_centrality.get(node_index),
                local_clustering_coefficient.get(node_index),
                node_degrees.get(node_index),
                in_degrees.get(node_index),
                out_degrees.get(node_index),
            ]
            all_node_features.append(features)
            all_node_labels.append(int(node_name[1:]))

        for remaining_node_name in remaining_label:
            all_node_features.append([0.0] * num_features)
            all_node_labels.append(int(remaining_node_name[1:]))

        x = torch.tensor(np.stack(all_node_features), dtype=torch.float)
        edge_index = G.edge_index
        y = torch.tensor(np.stack(all_node_labels), dtype=torch.long)
        dataset = Data(x=x, 
                       edge_index=edge_index, 
                       y=y)
        dataset.transform = NormalizeFeatures()

        individual_graphs_data.append(dataset)

    with open('individual_graphs.pkl', 'wb') as file:
        pickle.dump(individual_graphs_data, file)

num_classes = 40
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

with open('individual_graphs.pkl', 'rb') as file:
    individual_graphs_data = pickle.load(file)

# Splitting data into train and remaining data
train_dataset, remaining_data = train_test_split(individual_graphs_data, train_size=0.7, random_state=42)

# Splitting remaining data into validation and test
val_dataset, test_dataset = train_test_split(remaining_data, train_size=0.5, random_state=42)

print(f"Number of training samples: {len(train_dataset)}")
print(f"Number of validation samples: {len(val_dataset)}")
print(f"Number of test samples: {len(test_dataset)}")

# DataLoader
train_loader = DataLoader(train_dataset, batch_size=64, shuffle=True)
val_loader = DataLoader(val_dataset, batch_size=64, shuffle=False)
test_loader = DataLoader(test_dataset, batch_size=64, shuffle=False)

model = GraphSAGE(in_channels=num_features, hidden_channels=16, out_channels=num_classes, num_layers=2)
optimizer = torch.optim.Adam(model.parameters(), lr=0.01)
criterion = torch.nn.CrossEntropyLoss()

def train():
    model.train()
    running_loss = 0.0

    for data in train_loader:  # Iterate in batches over the training dataset.
         out = model(data.x, data.edge_index)  # Perform a single forward pass.
         loss = criterion(out, data.y)  # Compute the loss.
         loss.backward()  # Derive gradients.
         optimizer.step()  # Update parameters based on gradients.
         optimizer.zero_grad()  # Clear gradients.

         running_loss += loss.item()

    return running_loss / len(train_loader)

def test(loader):
     model.eval()

     correct = 0
     total = 0
     for data in loader:  # Iterate in batches over the training/test dataset.
         out = model(data.x, data.edge_index)  
         pred = out.argmax(dim=1)  # Use the class with highest probability.
         correct += int((pred == data.y).sum())  # Check against ground-truth labels.
         total += len(data.y)
     return correct / total  # Derive ratio of correct predictions.

best_val_acc = 0.0
best_model_state_dict = None
for epoch in range(1, 1001):
    loss = train()
    train_acc = test(train_loader)
    val_acc = test(val_loader)

    if val_acc > best_val_acc:
        best_val_acc = val_acc
        best_model_state_dict = model.state_dict()
    print(f'Epoch: {epoch:03d}, Train Loss: {loss:.4f}, Train Acc: {train_acc:.4f}, Val Acc: {val_acc:.4f}')

if best_model_state_dict:
    torch.save(best_model_state_dict, 'best_model.pth')

individual_graphs_data

(screenshot, 2023-12-19, showing the printed contents of individual_graphs_data)
EdisonLeeeee commented 10 months ago

Hi @tekeburak, I don't think you need to compute the loss/accuracy on those "remaining nodes". You need an additional attribute, like a mask, that indicates which nodes exist.
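A sketch of that idea, assuming each padded Data carries a boolean mask attribute over its 40 node slots (mask and masked_metrics are hypothetical names):

import torch.nn.functional as F

def masked_metrics(out, data):
    # Restrict loss and accuracy to nodes that really exist in the graph;
    # zero-padded slots are excluded entirely.
    loss = F.cross_entropy(out[data.mask], data.y[data.mask])
    pred = out[data.mask].argmax(dim=-1)
    acc = (pred == data.y[data.mask]).float().mean().item()
    return loss, acc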

tekeburak commented 10 months ago

Thank you @EdisonLeeeee and @rusty1s. I computed the loss and accuracy excluding the remaining nodes, and I got a more reasonable accuracy.

MoaazZaki commented 10 months ago

@tekeburak Can you please share the final snippet of the code? I am running into a very similar issue.