Open TKForgeron opened 1 year ago
Thanks for reporting. What happens if you change
self.conv1 = pygnn.GraphConv(-1, hidden_channels)
to
self.conv1 = pygnn.GraphConv((-1, -1), hidden_channels)
Otherwise, do you mind sharing a single data object with me?
Hi Matthias,
Thanks for your reply; I did not see it back then. In the meantime, I have fixed the issue on my side.
I tried your suggestion, but it did not help.
How I solved it: I wrote some extra tests that went through the whole dataset and checked whether the edge_index values of each edge type were valid. I considered them valid if the largest index in edge_index was smaller than the number of nodes of the respective node type in the HeteroData object in question.
For those interested, this is the code I used for the validation:
def validate_cs_hoeg_dataset(dataset: HOEG, verbose: bool = True) -> list[HeteroData]:
    event_ids_ev_ev = []
    krs_ids_krs_ev = []
    krv_ids_krv_ev = []
    cv_ids_cv_ev = []
    event_ids_krs_ev = []
    event_ids_krv_ev = []
    event_ids_cv_ev = []
    krs_ids_krs_krs = []
    krv_ids_krv_krv = []
    cv_ids_cv_cv = []
    invalid_batches = []
    for batch in dataset:
        batch: HeteroData
        ev_ev_ev = (
            batch["event"].num_nodes
            == batch["event"].x.shape[0]
            == int(batch["event", "follows", "event"].edge_index.max() + 1)
        )
        krs_krs_ev = (
            batch["krs"].num_nodes
            == batch["krs"].x.shape[0]
            == int(batch["krs", "interacts", "event"].edge_index[0].max() + 1)
        )
        krv_krv_ev = (
            batch["krv"].num_nodes
            == batch["krv"].x.shape[0]
            == int(batch["krv", "interacts", "event"].edge_index[0].max() + 1)
        )
        cv_cv_ev = (
            batch["cv"].num_nodes
            == batch["cv"].x.shape[0]
            == int(batch["cv", "interacts", "event"].edge_index[0].max() + 1)
        )
        ev_krs_ev = batch["event"].num_nodes == batch["event"].x.shape[0] and batch[
            "event"
        ].num_nodes >= int(batch["krs", "interacts", "event"].edge_index[1].max() + 1)
        ev_krv_ev = batch["event"].num_nodes == batch["event"].x.shape[0] and batch[
            "event"
        ].num_nodes >= int(batch["krv", "interacts", "event"].edge_index[1].max() + 1)
        ev_cv_ev = batch["event"].num_nodes == batch["event"].x.shape[0] and batch[
            "event"
        ].num_nodes >= int(batch["cv", "interacts", "event"].edge_index[1].max() + 1)
        krs_krs_krs = (
            batch["krs"].x.shape[0]
            == batch["krs"].num_nodes
            == int(batch["krs", "updates", "krs"].edge_index.max() + 1)
        )
        krv_krv_krv = (
            batch["krv"].x.shape[0]
            == batch["krv"].num_nodes
            == int(batch["krv", "updates", "krv"].edge_index.max() + 1)
        )
        cv_cv_cv = (
            batch["cv"].x.shape[0]
            == batch["cv"].num_nodes
            == int(batch["cv", "updates", "cv"].edge_index.max() + 1)
        )
        event_ids_ev_ev.append(ev_ev_ev)
        krs_ids_krs_ev.append(krs_krs_ev)
        krv_ids_krv_ev.append(krv_krv_ev)
        cv_ids_cv_ev.append(cv_cv_ev)
        event_ids_krs_ev.append(ev_krs_ev)
        event_ids_krv_ev.append(ev_krv_ev)
        event_ids_cv_ev.append(ev_cv_ev)
        krs_ids_krs_krs.append(krs_krs_krs)
        krv_ids_krv_krv.append(krv_krv_krv)
        cv_ids_cv_cv.append(cv_cv_cv)
        if not all(
            [
                ev_ev_ev,
                krs_krs_ev,
                krv_krv_ev,
                cv_cv_ev,
                ev_krs_ev,
                ev_krv_ev,
                ev_cv_ev,
                krs_krs_krs,
                krv_krv_krv,
                cv_cv_cv,
            ]
        ):
            invalid_batches.append(batch)
    if verbose:
        print(
            "Event node indices valid in all HeteroData for edge type event-event: ",
            all(event_ids_ev_ev),
        )
        print(
            "KRS node indices valid in all HeteroData for edge type krs-event: ",
            all(krs_ids_krs_ev),
        )
        print(
            "KRV node indices valid in all HeteroData for edge type krv-event: ",
            all(krv_ids_krv_ev),
        )
        print(
            "CV node indices valid in all HeteroData for edge type cv-event: ",
            all(cv_ids_cv_ev),
        )
        print(
            "Event node indices valid in all HeteroData for edge type krs-event: ",
            all(event_ids_krs_ev),
        )
        print(
            "Event node indices valid in all HeteroData for edge type krv-event: ",
            all(event_ids_krv_ev),
        )
        print(
            "Event node indices valid in all HeteroData for edge type cv-event: ",
            all(event_ids_cv_ev),
        )
        print(
            "KRS node indices valid in all HeteroData for edge type krs-krs: ",
            all(krs_ids_krs_krs),
        )
        print(
            "KRV node indices valid in all HeteroData for edge type krv-krv: ",
            all(krv_ids_krv_krv),
        )
        print(
            "CV node indices valid in all HeteroData for edge type cv-cv: ",
            all(cv_ids_cv_cv),
        )
        print()
        print("HOEG dataset valid: ", not (len(invalid_batches)))
    return invalid_batches
Some HeteroData objects did have edge_index values pointing to node indices larger than the total number of nodes of the corresponding node type. I ended up solving this by replacing the invalid edge_index values with 0, which was always a valid index in my case.
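The replacement step is not shown in the thread, so here is a minimal sketch of what it could look like. It uses plain Python lists as a stand-in for the 2 x num_edges edge_index tensor; the function name `replace_out_of_range` and its parameters are hypothetical and not part of PyG.

```python
def replace_out_of_range(edge_index, num_src, num_dst, fill=0):
    """Return a copy of edge_index with out-of-range node indices set to `fill`.

    edge_index is [source_ids, target_ids]: two equally long lists of ints,
    standing in for the 2 x num_edges tensor (assumption for this sketch).
    """
    src, dst = edge_index
    # Row 0 indexes the source node type, row 1 the target node type.
    fixed_src = [i if 0 <= i < num_src else fill for i in src]
    fixed_dst = [j if 0 <= j < num_dst else fill for j in dst]
    return [fixed_src, fixed_dst]


edge_index = [[0, 1, 5], [2, 7, 0]]
fixed = replace_out_of_range(edge_index, num_src=3, num_dst=4)
print(fixed)  # [[0, 1, 0], [2, 0, 0]]
```

Note that this rewires every broken edge to node 0 instead of removing it, so it changes the graph structure. Dropping the offending edges would be an alternative if node 0 should not receive extra connections.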
Thanks again!
PS: Interestingly, the .validate() method did not find the issue. It might be worth including such a check in the validation logic. @rusty1s
Cool that you have solved it. Can you clarify your comment on validate() though? Checking for out-of-bound indices in edge_index for every edge type already happens within HeteroData.validate() AFAIK.
All I know is that validate() did not raise a warning or error and returned True. Maybe validate() does not check all edge types for a HeteroData object?
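For anyone who wants to cross-check this independently of validate(), a generic bounds check over all edge types could look roughly like the sketch below. It is not the PyG implementation: plain dicts and lists stand in for the HeteroData object, and `find_invalid_edge_types` is a hypothetical helper name.

```python
def find_invalid_edge_types(num_nodes, edge_indices):
    """Return the edge types whose edge_index points outside the node range.

    num_nodes:    {node_type: number of nodes}
    edge_indices: {(src_type, relation, dst_type): [source_ids, target_ids]}
    Both are plain-Python stand-ins for a HeteroData object (assumption).
    """
    invalid = []
    for (src_type, rel, dst_type), (src, dst) in edge_indices.items():
        # Source ids must index into src_type, target ids into dst_type.
        src_ok = all(0 <= i < num_nodes[src_type] for i in src)
        dst_ok = all(0 <= j < num_nodes[dst_type] for j in dst)
        if not (src_ok and dst_ok):
            invalid.append((src_type, rel, dst_type))
    return invalid


num_nodes = {"event": 3, "krs": 2}
edge_indices = {
    ("event", "follows", "event"): [[0, 1], [1, 2]],
    ("krs", "interacts", "event"): [[0, 1], [2, 3]],  # target 3 is out of range
}
print(find_invalid_edge_types(num_nodes, edge_indices))
# [('krs', 'interacts', 'event')]
```

Because it loops over whatever edge types are present, this sketch does not need one hand-written check per relation, and an edge type with no edges simply passes.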
Describe the bug
I made my own custom dataset consisting of HeteroData graphs. This is what one such graph looks like:
I use the standard DataLoader to create batches of size 16.
I then train using this model:
Which I converted to a hetero one:
I then got this lovely error:
I don't understand why this is happening, as I've already run data.validate() on each graph in my dataset, and I've also checked the validity of some of the edge types:
Which gives:
All batches valid: True
What could be going on here? And what could I do to fix it?
Any help would be very much appreciated.
Environment
How installed (conda, pip, source): pip