pyg-team / pytorch_geometric

Graph Neural Network Library for PyTorch
https://pyg.org
MIT License

HeteroData IndexError when training. IndexError: Encountered an index error. Please ensure that all indices in 'edge_index' point to valid indices in the interval [0, 15] (got interval [0, 25]) #7949

Open TKForgeron opened 1 year ago

TKForgeron commented 1 year ago

๐Ÿ› Describe the bug

I made my own custom dataset consisting of HeteroData graphs. This is what one such graph looks like:

HeteroData(
    event={
      x=[26, 14],
      y=[26]
    },
    krs={
      x=[1, 20],
      y=[1]
    },
    krv={
      x=[1, 20],
      y=[1]
    },
    cv={
      x=[1, 20],
      y=[1]
    },
    (event, follows, event)={ edge_index=[2, 51] },
    (krs, interacts, event)={ edge_index=[2, 6] },
    (krv, interacts, event)={ edge_index=[2, 21] },
    (cv, interacts, event)={ edge_index=[2, 1] },
    (krs, updates, krs)={ edge_index=[2, 1] },
    (krv, updates, krv)={ edge_index=[2, 1] },
    (cv, updates, cv)={ edge_index=[2, 1] }
)

I use the standard DataLoader to create batches of size 16.
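For context on the interval in the error further down: when PyG's DataLoader collates HeteroData objects, it concatenates the nodes of each type and shifts every edge_index. Row 0 is offset by the number of source-type nodes already in the batch, and row 1 by the number of destination-type nodes. The pure-Python sketch below mimics that bookkeeping for a single edge type. The helper collate_edge_index is hypothetical and plain lists stand in for tensors. With one krv node per graph, a batch of 16 graphs has exactly 16 krv nodes, so the only valid krv indices are 0 to 15.

```python
from typing import List, Tuple

EdgeIndex = List[List[int]]  # [[src_0, src_1, ...], [dst_0, dst_1, ...]]


def collate_edge_index(
    graphs: List[Tuple[EdgeIndex, int, int]],
) -> Tuple[EdgeIndex, int, int]:
    """Concatenate the edge_index of ONE edge type across graphs, mimicking
    the per-type increment applied when batching (hypothetical helper).

    Each entry is (edge_index, num_src_nodes, num_dst_nodes) for one graph.
    Returns the batched edge_index and the total src/dst node counts.
    """
    src_out: List[int] = []
    dst_out: List[int] = []
    src_offset = dst_offset = 0
    for edge_index, num_src, num_dst in graphs:
        # Shift source indices by the source-type nodes seen so far and
        # destination indices by the destination-type nodes seen so far.
        src_out.extend(i + src_offset for i in edge_index[0])
        dst_out.extend(i + dst_offset for i in edge_index[1])
        src_offset += num_src
        dst_offset += num_dst
    return [src_out, dst_out], src_offset, dst_offset


# (krv, interacts, event) in a toy graph: 1 krv node, 26 event nodes.
good = ([[0, 0, 0], [0, 5, 25]], 1, 26)
batched, num_krv, num_event = collate_edge_index([good] * 16)
print(num_krv, max(batched[0]))  # 16 15 -> valid krv range is [0, 15]

# Hypothetical bad graph: an event index (25) sits in the krv row.
bad = ([[0, 25], [0, 3]], 1, 26)
batched, num_krv, _ = collate_edge_index([bad] + [good] * 15)
print(num_krv, min(batched[0]), max(batched[0]))  # 16 0 25
```

This is only one hypothetical way to end up with exactly "[0, 15] (got interval [0, 25])"; the actual cause in this dataset is not known from the traceback alone.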

I then train using this model:

class GNN(torch.nn.Module):
    def __init__(
        self,
        hidden_channels: int = 64,
        out_channels: int = 1,
        pre_forward_view: bool = False,
        squeeze: bool = True,
    ):
        super().__init__()
        self.squeeze = squeeze
        self.conv1 = pygnn.GraphConv(-1, hidden_channels)
        self.act1 = nn.PReLU()
        self.lin_out = pygnn.Linear(-1, out_channels)

    def forward(self, x, edge_index, batch=None):
        x = self.conv1(x, edge_index)
        x = self.act1(x)
        x = self.lin_out(x)
        x = torch.squeeze(x) if self.squeeze else x
        return x

I then converted it to a heterogeneous model with to_hetero:

meta_data = (['event', 'krs', 'krv', 'cv'],
       [('event', 'follows', 'event'),
        ('krs', 'interacts', 'event'),
        ('krv', 'interacts', 'event'),
        ('cv', 'interacts', 'event'),
        ('krs', 'updates', 'krs'),
        ('krv', 'updates', 'krv'),
        ('cv', 'updates', 'cv')])
model = GNN()
model = to_hetero(model, meta_data)
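to_hetero makes one copy of each message-passing layer per edge type (and one copy of other layers, such as Linear, per node type), then sums the outputs of all edge types that share a destination node type (aggr='sum' by default). That is why the traced forward in the traceback calls the conv four times for event, once per incoming edge type. Below is a toy pure-Python sketch of that dispatch; toy_conv is a hypothetical stand-in that just sums source values, not GraphConv and not PyG's internal implementation.

```python
from typing import Dict, List, Tuple

EdgeType = Tuple[str, str, str]


def toy_conv(x_src: List[float], edge_index: List[List[int]], num_dst: int) -> List[float]:
    """Stand-in message passing: each destination node sums its sources' values."""
    out = [0.0] * num_dst
    for j, i in zip(edge_index[0], edge_index[1]):
        # A source index j >= len(x_src) raises IndexError here.
        out[i] += x_src[j]
    return out


def hetero_layer(
    x_dict: Dict[str, List[float]],
    edge_index_dict: Dict[EdgeType, List[List[int]]],
) -> Dict[str, List[float]]:
    """Run one toy conv per edge type, then sum the results per destination type."""
    out: Dict[str, List[float]] = {}
    for (src, _, dst), edge_index in edge_index_dict.items():
        msg = toy_conv(x_dict[src], edge_index, len(x_dict[dst]))
        if dst in out:
            out[dst] = [a + b for a, b in zip(out[dst], msg)]
        else:
            out[dst] = msg
    return out


x_dict = {"event": [1.0, 2.0, 3.0], "krv": [10.0]}
edge_index_dict = {
    ("event", "follows", "event"): [[0, 1], [1, 2]],
    ("krv", "interacts", "event"): [[0, 0], [0, 2]],
    ("krv", "updates", "krv"): [[0], [0]],
}
print(hetero_layer(x_dict, edge_index_dict))
# {'event': [10.0, 1.0, 12.0], 'krv': [10.0]}
```

The krv -> event call indexes the krv features with row 0 of that edge type, so an out-of-range krv index fails in that call specifically, as in the traceback.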

I then got this lovely error:

---------------------------------------------------------------------------
IndexError                                Traceback (most recent call last)
File ~/Development/OCPPM/.env/lib/python3.9/site-packages/torch_geometric/nn/conv/message_passing.py:272, in MessagePassing._lift(self, src, edge_index, dim)
    271     index = edge_index[dim]
--> 272     return src.index_select(self.node_dim, index)
    273 except (IndexError, RuntimeError) as e:

IndexError: index out of range in self

During handling of the above exception, another exception occurred:

IndexError                                Traceback (most recent call last)
Cell In[52], line 11
      9 for lr in lr_range:
     10     for hidden_dim in hidden_dim_range:
---> 11         hetero_experiment_utils.run_hoeg_experiment_configuration(
     12             HigherOrderGNN,
     13             lr=lr,
     14             hidden_dim=hidden_dim,
     15             train_loader=train_loader,
     16             val_loader=val_loader,
     17             test_loader=test_loader,
     18             hoeg_config=cs_hoeg_config,
     19         )

File ~/Development/OCPPM/utilities/hetero_experiment_utils.py:54, in run_hoeg_experiment_configuration(model_class, lr, hidden_dim, train_loader, val_loader, test_loader, hoeg_config)
     51 timestamp = start_train_time.strftime("%Y%m%d_%Hh%Mm")
     52 model_path_base = f"{hoeg_config['model_output_path']}/lr={hoeg_config['optimizer_settings']['lr']}_hidden_dim={hoeg_config['hidden_dim']}/{str(model).split('(')[0]}_{timestamp}"
---> 54 best_state_dict_path = hetero_training_utils.run_training_hetero(
     55     target_node_type=hoeg_config["target_node_type"],
     56     num_epochs=hoeg_config["EPOCHS"],
     57     model=model,
     58     train_loader=train_loader,
     59     validation_loader=val_loader,
     60     optimizer=hoeg_config["optimizer"](
     61         model.parameters(), **hoeg_config["optimizer_settings"]
     62     ),
     63     loss_fn=hoeg_config["loss_fn"],
     64     early_stopping_criterion=hoeg_config["early_stopping"],
     65     model_path_base=model_path_base,
     66     device=hoeg_config["device"],
     67     verbose=hoeg_config["verbose"],
     68     squeeze_required=hoeg_config["squeeze"],
     69 )
     70 total_train_time = datetime.now() - start_train_time
     72 # Write experiment settings as JSON into model path (of the model we've just trained)

File ~/Development/OCPPM/utilities/hetero_training_utils.py:107, in run_training_hetero(target_node_type, num_epochs, model, train_loader, validation_loader, optimizer, loss_fn, early_stopping_criterion, device, model_path_base, verbose, squeeze_required)
    104 for epoch in range(num_epochs):
    105     # Make sure gradient tracking is on, and do a pass over the data
    106     model.train(True)
--> 107     avg_loss = train_one_epoch_hetero(
    108         target_node_type,
    109         epoch,
    110         model,
    111         train_loader,
    112         optimizer,
    113         loss_fn,
    114         writer,
    115         device,
    116         verbose,
    117         squeeze_required,
    118     )
    120     # We don't need gradients on to do reporting
    121     model.train(False)

File ~/Development/OCPPM/utilities/hetero_training_utils.py:53, in train_one_epoch_hetero(target_node_type, epoch_index, model, train_loader, optimizer, loss_fn, tb_writer, device, verbose, squeeze_required)
     51 optimizer.zero_grad(set_to_none=True)
     52 # Passing the node features and the connection info
---> 53 outputs = model(
     54     inputs, edge_index=adjacency_matrix  # , batch=batch[target_node_type].batch
     55 )
     56 # Compute loss and gradients
     57 if squeeze_required:

File ~/Development/OCPPM/.env/lib/python3.9/site-packages/torch/fx/graph_module.py:658, in GraphModule.recompile..call_wrapped(self, *args, **kwargs)
    657 def call_wrapped(self, *args, **kwargs):
--> 658     return self._wrapped_call(self, *args, **kwargs)

File ~/Development/OCPPM/.env/lib/python3.9/site-packages/torch/fx/graph_module.py:277, in _WrappedCall.__call__(self, obj, *args, **kwargs)
    275     raise e.with_traceback(None)
    276 else:
--> 277     raise e

File ~/Development/OCPPM/.env/lib/python3.9/site-packages/torch/fx/graph_module.py:267, in _WrappedCall.__call__(self, obj, *args, **kwargs)
    265         return self.cls_call(obj, *args, **kwargs)
    266     else:
--> 267         return super(self.cls, obj).__call__(*args, **kwargs)  # type: ignore[misc]
    268 except Exception as e:
    269     assert e.__traceback__

File ~/Development/OCPPM/.env/lib/python3.9/site-packages/torch/nn/modules/module.py:1194, in Module._call_impl(self, *input, **kwargs)
   1190 # If we don't have any hooks, we want to skip the rest of the logic in
   1191 # this function, and just call forward.
   1192 if not (self._backward_hooks or self._forward_hooks or self._forward_pre_hooks or _global_backward_hooks
   1193         or _global_forward_hooks or _global_forward_pre_hooks):
-> 1194     return forward_call(*input, **kwargs)
   1195 # Do not call functions when jit is used
   1196 full_backward_hooks, non_full_backward_hooks = [], []

File .19:20, in forward(self, x, edge_index, batch)
     18 convs_0__event1 = getattr(self.convs, "0").event__follows__event(x__event, edge_index__event__follows__event)
     19 convs_0__event2 = getattr(self.convs, "0").krs__interacts__event((x__krs, x__event), edge_index__krs__interacts__event)
---> 20 convs_0__event3 = getattr(self.convs, "0").krv__interacts__event((x__krv, x__event), edge_index__krv__interacts__event)
     21 convs_0__event4 = getattr(self.convs, "0").cv__interacts__event((x__cv, x__event), edge_index__cv__interacts__event);  x__event = None
     22 convs_0__krs = getattr(self.convs, "0").krs__updates__krs(x__krs, edge_index__krs__updates__krs);  x__krs = None

File ~/Development/OCPPM/.env/lib/python3.9/site-packages/torch/nn/modules/module.py:1194, in Module._call_impl(self, *input, **kwargs)
   1190 # If we don't have any hooks, we want to skip the rest of the logic in
   1191 # this function, and just call forward.
   1192 if not (self._backward_hooks or self._forward_hooks or self._forward_pre_hooks or _global_backward_hooks
   1193         or _global_forward_hooks or _global_forward_pre_hooks):
-> 1194     return forward_call(*input, **kwargs)
   1195 # Do not call functions when jit is used
   1196 full_backward_hooks, non_full_backward_hooks = [], []

File ~/Development/OCPPM/.env/lib/python3.9/site-packages/torch_geometric/nn/conv/graph_conv.py:86, in GraphConv.forward(self, x, edge_index, edge_weight, size)
     83     x: OptPairTensor = (x, x)
     85 # propagate_type: (x: OptPairTensor, edge_weight: OptTensor)
---> 86 out = self.propagate(edge_index, x=x, edge_weight=edge_weight,
     87                      size=size)
     88 out = self.lin_rel(out)
     90 x_r = x[1]

File ~/Development/OCPPM/.env/lib/python3.9/site-packages/torch_geometric/nn/conv/message_passing.py:459, in MessagePassing.propagate(self, edge_index, size, **kwargs)
    456     for arg in decomp_args:
    457         kwargs[arg] = decomp_kwargs[arg][i]
--> 459 coll_dict = self._collect(self._user_args, edge_index, size,
    460                           kwargs)
    462 msg_kwargs = self.inspector.distribute('message', coll_dict)
    463 for hook in self._message_forward_pre_hooks.values():

File ~/Development/OCPPM/.env/lib/python3.9/site-packages/torch_geometric/nn/conv/message_passing.py:336, in MessagePassing._collect(self, args, edge_index, size, kwargs)
    334         if isinstance(data, Tensor):
    335             self._set_size(size, dim, data)
--> 336             data = self._lift(data, edge_index, dim)
    338         out[arg] = data
    340 if is_torch_sparse_tensor(edge_index):

File ~/Development/OCPPM/.env/lib/python3.9/site-packages/torch_geometric/nn/conv/message_passing.py:275, in MessagePassing._lift(self, src, edge_index, dim)
    273 except (IndexError, RuntimeError) as e:
    274     if index.min() < 0 or index.max() >= src.size(self.node_dim):
--> 275         raise IndexError(
    276             f"Encountered an index error. Please ensure that all "
    277             f"indices in 'edge_index' point to valid indices in "
    278             f"the interval [0, {src.size(self.node_dim) - 1}] "
    279             f"(got interval "
    280             f"[{int(index.min())}, {int(index.max())}])")
    281     else:
    282         raise e

IndexError: Encountered an index error. Please ensure that all indices in 'edge_index' point to valid indices in the interval [0, 15] (got interval [0, 25])
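The message is produced by the fallback branch of MessagePassing._lift shown in the traceback: source-node features are gathered with edge_index[0], and when that fails PyG compares the index range against the number of rows of the source tensor. The sketch below re-creates that check in plain Python. The function lift is a hypothetical stand-in, and unlike PyG it validates up front instead of after a failed index_select. Here x_krv has 16 rows, one krv node per graph in a batch of 16, which matches the [0, 15] in the message.

```python
from typing import List, Sequence


def lift(src: Sequence[Sequence[float]], index: List[int]) -> List[Sequence[float]]:
    """Gather rows of `src` by `index`, mimicking the range check in _lift."""
    num_nodes = len(src)
    if index and (min(index) < 0 or max(index) >= num_nodes):
        raise IndexError(
            "Encountered an index error. Please ensure that all "
            "indices in 'edge_index' point to valid indices in "
            f"the interval [0, {num_nodes - 1}] "
            f"(got interval [{min(index)}, {max(index)}])"
        )
    return [src[i] for i in index]


# 16 krv nodes (20 features each) in the batch, but row 0 contains index 25.
x_krv = [[0.0] * 20 for _ in range(16)]
try:
    lift(x_krv, [0, 3, 25])
except IndexError as e:
    print(e)
# Encountered an index error. Please ensure that all indices in 'edge_index'
# point to valid indices in the interval [0, 15] (got interval [0, 25])
```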

I don't understand why this is happening, as I've already run data.validate() on each graph in my dataset, and I've also checked the validity of some of the edge types:

valid_batches = []
for loader in [train_loader,val_loader,test_loader]:
    for batch in loader:
        is_valid = batch['event', 'follows', 'event'].edge_index.max()+1 == batch['event'].x.shape[0]
        valid_batches.append(is_valid)
print('All batches valid: ', all(valid_batches))

Which gives: All batches valid: True
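Note that this loop only inspects ('event', 'follows', 'event'), while the traceback fails inside the krv__interacts__event conv. A check that covers every edge type has to compare row 0 against the source type's node count and row 1 against the destination type's. The sketch below does this over plain dicts; in PyG the same numbers would come from batch.edge_types, batch[src].num_nodes and batch[edge_type].edge_index. The helper find_out_of_range is hypothetical.

```python
from typing import Dict, List, Tuple

EdgeType = Tuple[str, str, str]


def find_out_of_range(
    num_nodes: Dict[str, int],
    edge_index_dict: Dict[EdgeType, List[List[int]]],
) -> List[Tuple[EdgeType, str, int, int]]:
    """Return (edge_type, row, max_index, num_nodes) for every out-of-range row."""
    problems: List[Tuple[EdgeType, str, int, int]] = []
    for edge_type, (row_src, row_dst) in edge_index_dict.items():
        src, _, dst = edge_type
        # Row 0 indexes the source node type, row 1 the destination node type.
        checks = (("src", row_src, num_nodes[src]), ("dst", row_dst, num_nodes[dst]))
        for name, row, n in checks:
            if row and (min(row) < 0 or max(row) >= n):
                problems.append((edge_type, name, max(row), n))
    return problems


num_nodes = {"event": 416, "krv": 16}
edge_index_dict = {
    ("event", "follows", "event"): [[0, 1, 414], [1, 2, 415]],
    ("krv", "interacts", "event"): [[0, 25], [0, 3]],
}
print(find_out_of_range(num_nodes, edge_index_dict))
# [(('krv', 'interacts', 'event'), 'src', 25, 16)]
```

The check uses max < num_nodes rather than max + 1 == num_nodes, since the highest-numbered node of a type does not have to appear in every edge type.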

What could be going on here? And what could I do to fix it?

Any help would be very much appreciated 🙏🏽 😀

Environment

rusty1s commented 1 year ago

Thanks for reporting. What happens if you change

self.conv1 = pygnn.GraphConv(-1, hidden_channels)

to

self.conv1 = pygnn.GraphConv((-1, -1), hidden_channels)

Otherwise, do you mind sharing a single data object with me?

TKForgeron commented 1 year ago

Hi Matthias,

Thanks for your reply. I did not see it at the time, and the issue has since been fixed on my side.

I tried your suggestion, but it did not help.

How I solved it: I wrote some extra tests that went through the whole dataset and checked whether the edge_index values of each edge type were valid. I considered them valid if the largest index in each row of an edge_index was smaller than the number of nodes of the corresponding node type (the source type for the first row, the destination type for the second) in the HeteroData object in question.

For those interested, this is the code I used for the validation:

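from torch_geometric.data import HeteroData

# HOEG is my own custom dataset class (not part of PyG).
def validate_cs_hoeg_dataset(dataset: HOEG, verbose: bool = True) -> list[HeteroData]:
    """Return every graph in `dataset` whose edge indices fail the checks below."""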
    event_ids_ev_ev = []
    krs_ids_krs_ev = []
    krv_ids_krv_ev = []
    cv_ids_cv_ev = []
    event_ids_krs_ev = []
    event_ids_krv_ev = []
    event_ids_cv_ev = []
    krs_ids_krs_krs = []
    krv_ids_krv_krv = []
    cv_ids_cv_cv = []
    invalid_batches = []
    for batch in dataset:
        batch: HeteroData
        ev_ev_ev = (
            batch["event"].num_nodes
            == batch["event"].x.shape[0]
            == int(batch["event", "follows", "event"].edge_index.max() + 1)
        )
        krs_krs_ev = (
            batch["krs"].num_nodes
            == batch["krs"].x.shape[0]
            == int(batch["krs", "interacts", "event"].edge_index[0].max() + 1)
        )
        krv_krv_ev = (
            batch["krv"].num_nodes
            == batch["krv"].x.shape[0]
            == int(batch["krv", "interacts", "event"].edge_index[0].max() + 1)
        )
        cv_cv_ev = (
            batch["cv"].num_nodes
            == batch["cv"].x.shape[0]
            == int(batch["cv", "interacts", "event"].edge_index[0].max() + 1)
        )
        ev_krs_ev = batch["event"].num_nodes == batch["event"].x.shape[0] and batch[
            "event"
        ].num_nodes >= int(batch["krs", "interacts", "event"].edge_index[1].max() + 1)
        ev_krv_ev = batch["event"].num_nodes == batch["event"].x.shape[0] and batch[
            "event"
        ].num_nodes >= int(batch["krv", "interacts", "event"].edge_index[1].max() + 1)
        ev_cv_ev = batch["event"].num_nodes == batch["event"].x.shape[0] and batch[
            "event"
        ].num_nodes >= int(batch["cv", "interacts", "event"].edge_index[1].max() + 1)
        krs_krs_krs = (
            batch["krs"].x.shape[0]
            == batch["krs"].num_nodes
            == int(batch["krs", "updates", "krs"].edge_index.max() + 1)
        )
        krv_krv_krv = (
            batch["krv"].x.shape[0]
            == batch["krv"].num_nodes
            == int(batch["krv", "updates", "krv"].edge_index.max() + 1)
        )
        cv_cv_cv = (
            batch["cv"].x.shape[0]
            == batch["cv"].num_nodes
            == int(batch["cv", "updates", "cv"].edge_index.max() + 1)
        )

        event_ids_ev_ev.append(ev_ev_ev)
        krs_ids_krs_ev.append(krs_krs_ev)
        krv_ids_krv_ev.append(krv_krv_ev)
        cv_ids_cv_ev.append(cv_cv_ev)
        event_ids_krs_ev.append(ev_krs_ev)
        event_ids_krv_ev.append(ev_krv_ev)
        event_ids_cv_ev.append(ev_cv_ev)
        krs_ids_krs_krs.append(krs_krs_krs)
        krv_ids_krv_krv.append(krv_krv_krv)
        cv_ids_cv_cv.append(cv_cv_cv)
        if not all(
            [
                ev_ev_ev,
                krs_krs_ev,
                krv_krv_ev,
                cv_cv_ev,
                ev_krs_ev,
                ev_krv_ev,
                ev_cv_ev,
                krs_krs_krs,
                krv_krv_krv,
                cv_cv_cv,
            ]
        ):
            invalid_batches.append(batch)
    if verbose:
        print(
            "Event node indices valid in all HeteroData for edge type event-event: ",
            all(event_ids_ev_ev),
        )
        print(
            "KRS node indices valid in all HeteroData for edge type krs-event: ",
            all(krs_ids_krs_ev),
        )
        print(
            "KRV node indices valid in all HeteroData for edge type krv-event: ",
            all(krv_ids_krv_ev),
        )
        print(
            "CV node indices valid in all HeteroData for edge type cv-event: ",
            all(cv_ids_cv_ev),
        )
        print(
            "Event node indices valid in all HeteroData for edge type krs-event: ",
            all(event_ids_krs_ev),
        )
        print(
            "Event node indices valid in all HeteroData for edge type krv-event: ",
            all(event_ids_krv_ev),
        )
        print(
            "Event node indices valid in all HeteroData for edge type cv-event: ",
            all(event_ids_cv_ev),
        )
        print(
            "KRS node indices valid in all HeteroData for edge type krs-krs: ",
            all(krs_ids_krs_krs),
        )
        print(
            "KRV node indices valid in all HeteroData for edge type krv-krv: ",
            all(krv_ids_krv_krv),
        )
        print(
            "CV node indices valid in all HeteroData for edge type cv-cv: ",
            all(cv_ids_cv_cv),
        )
        print()
    print("HOEG dataset valid: ", not (len(invalid_batches)))
    return invalid_batches

Some HeteroData objects did indeed have edge_index values pointing to node indices beyond the number of nodes of the respective node type. I ended up solving this by replacing the invalid edge_index values with 0, which was always a valid index in my case.
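A minimal sketch of the repair described above, replacing out-of-range entries with 0 row by row. The helpers are hypothetical and operate on plain lists. This keeps the data loadable but silently reattaches those edges to node 0; with tensors, the same masking could be written with torch.where.

```python
from typing import List


def replace_invalid(row: List[int], num_nodes: int, fill: int = 0) -> List[int]:
    """Replace indices outside [0, num_nodes - 1] with `fill`."""
    return [i if 0 <= i < num_nodes else fill for i in row]


def repair_edge_index(edge_index: List[List[int]], num_src: int, num_dst: int) -> List[List[int]]:
    """Repair the source row and the destination row against their own node counts."""
    return [replace_invalid(edge_index[0], num_src), replace_invalid(edge_index[1], num_dst)]


# krv -> event in one graph: 1 krv node, 26 event nodes.
print(repair_edge_index([[0, 25, 0], [4, 7, 30]], num_src=1, num_dst=26))
# [[0, 0, 0], [4, 7, 0]]
```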

Thanks again!

PS: Interestingly, the .validate() method did not catch this issue. It might be worth adding such a check to validate(). @rusty1s

rusty1s commented 1 year ago

Cool that you have solved it. Can you clarify your comment on validate() though? AFAIK, HeteroData.validate() already checks every edge type for out-of-bounds indices in edge_index.

TKForgeron commented 11 months ago

All I know is that validate() did not raise a warning or error and returned True. Maybe validate() does not check all edge types for a HeteroData object?