Closed Breaddsmall closed 1 year ago
I found that the problem is caused by the last node of a batch not being connected to any other node in the batch, which means its index never appears in the adjacency matrix. Therefore, torch.scatter() does not aggregate anything for it, and the output has fewer rows than there are nodes, leading to the dimension mismatch.
Hi @Breaddsmall, thanks for your interest in our work, and sorry for my late reply! Yes, isolated nodes may cause some problems. You can try removing isolated nodes or adding self-loops to solve this issue.
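As a rough illustration of the two options above, here is a minimal pure-Python sketch on a plain edge list. The helper names and the toy graph are hypothetical and are not part of the repository; the sketch also assumes edges are stored as (source, target) pairs.

```python
def isolated_nodes(num_nodes, edges):
    """Return nodes that appear in no edge (hypothetical helper)."""
    touched = {n for edge in edges for n in edge}
    return [n for n in range(num_nodes) if n not in touched]


def add_self_loops(num_nodes, edges):
    """Append an edge (n, n) for every node that lacks one."""
    present = set(edges)
    return list(edges) + [(n, n) for n in range(num_nodes) if (n, n) not in present]


def remove_isolated(num_nodes, edges):
    """Drop isolated nodes and relabel the remaining ones as 0..k-1."""
    dropped = set(isolated_nodes(num_nodes, edges))
    kept = [n for n in range(num_nodes) if n not in dropped]
    relabel = {old: new for new, old in enumerate(kept)}
    return kept, [(relabel[s], relabel[t]) for s, t in edges]


# Toy graph: 5 nodes, directed (source, target) pairs; nodes 2 and 4 are isolated.
edges = [(0, 1), (1, 0), (1, 3), (3, 1)]
looped = add_self_loops(5, edges)          # every node now appears as a target
kept, compact = remove_isolated(5, edges)  # kept lists the original node ids

print(isolated_nodes(5, edges))
print(kept, compact)
```

If nodes are removed, the node features, positions, and labels would need to be filtered with the same `kept` list. A self-loop has zero length, so it is worth checking that a distance-based model handles a distance of 0 before relying on that option.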
In addition, in our implementation we only pass src, index, and dim to the scatter function. You can also try passing additional arguments such as out and dim_size, which can also solve this issue.
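To show why the output size matters, here is a minimal pure-Python sketch of a sum-scatter along dim 0. It is a stand-in for the library function, not its implementation: `scatter_sum` and the toy data are hypothetical. The assumption it illustrates is that, when no size is given, the output length is inferred as the largest index plus one.

```python
def scatter_sum(src, index, dim_size=None):
    """Sum rows of `src` into the output rows selected by `index` (hypothetical helper)."""
    if dim_size is None:
        # Without an explicit size, infer it from the largest index seen.
        dim_size = max(index) + 1 if index else 0
    width = len(src[0]) if src else 0
    out = [[0.0] * width for _ in range(dim_size)]
    for row, target in zip(src, index):
        for k, value in enumerate(row):
            out[target][k] += value
    return out


# Toy batch with 4 nodes. Nodes 1 and 3 receive no edge messages.
num_nodes = 4
edge_messages = [[1.0, 2.0], [3.0, 4.0], [5.0, 6.0]]
target_nodes = [0, 2, 2]

inferred = scatter_sum(edge_messages, target_nodes)
explicit = scatter_sum(edge_messages, target_nodes, dim_size=num_nodes)

print(len(inferred))  # rows when the size is inferred from the indices
print(len(explicit))  # rows when the size is passed explicitly
```

In this toy case the isolated node in the middle (node 1) still gets a zero row either way, while the trailing isolated node (node 3) only gets one when the size is passed explicitly, which matches the symptom described at the top of the thread.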
Best
Hello, I am not sure how to use the above solution without changing the original dimensions. Can you provide some examples?
Hi @pearl-rabbit,
About the scatter function I mentioned previously, could you try replacing lines 203-205 with
    def forward(self, e, i, num_nodes):
        _, e2 = e
        v = scatter(e2, i, dim=0, dim_size=num_nodes)
and changing line 301 to v = update_v(e, i, num_nodes)
to see if this solves your problem?
I tried to use SphereNet and SchNet on my custom materials datasets, and I ran into these problems during training.
Does this kind of problem occur because some of the nodes are not connected to any others within the given cutoff? If so, how should I fix it (other than setting a larger cutoff)?
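One way to check this hypothesis is to list the atoms that have no neighbor within the cutoff. The sketch below is a brute-force check in pure Python. The function name and the toy coordinates are hypothetical, it assumes a strict "distance < cutoff" rule, and it ignores periodic images, which a real materials dataset would probably need to account for.

```python
import math


def nodes_without_neighbors(positions, cutoff):
    """Return indices of points with no other point closer than `cutoff`."""
    isolated = []
    for a, pa in enumerate(positions):
        has_neighbor = any(
            a != b and math.dist(pa, pb) < cutoff
            for b, pb in enumerate(positions)
        )
        if not has_neighbor:
            isolated.append(a)
    return isolated


# Toy structure: three atoms close together and one far away.
positions = [(0.0, 0.0, 0.0), (1.0, 0.0, 0.0), (0.0, 1.2, 0.0), (9.0, 9.0, 9.0)]

print(nodes_without_neighbors(positions, cutoff=5.0))
print(nodes_without_neighbors(positions, cutoff=1.1))
```

If this returns a non-empty list for some samples, those samples are the ones where the aggregation can produce fewer rows than there are nodes. The check is quadratic in the number of atoms, so it is meant for diagnosing a dataset rather than for use inside the training loop.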