divelab / DIG

A library for graph deep learning research
https://diveintographs.readthedocs.io/
GNU General Public License v3.0

PGExplainer optimization #117

Closed: alirezadizaji closed this issue 2 years ago

alirezadizaji commented 2 years ago

Hi, I read the pseudocode in the paper, and it seems the loss calculation and backpropagation should be done after getting the prediction probabilities for all graph inputs (graph explanation) or nodes (node explanation). However, the current code calculates the loss and backpropagates it per sample.

Oceanusity commented 2 years ago

Hello, I think the optimization step uses gradient descent (GD): it accumulates the gradients from each sample and only updates the model weights once all the graph inputs have been included.

As shown here, the code calls loss.backward() for each graph, but we only update the model weights with optimizer.step() once the gradients from all the graphs have been collected. This way we save GPU memory and can train PGExplainer on datasets with thousands of graphs.
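For reference, here is a minimal, self-contained sketch of that accumulation pattern. The model, data, and loss below are toy stand-ins for PGExplainer's, not DIG's actual training loop:

```python
import torch

# Toy stand-ins for the explainer network and the per-graph data/loss.
torch.manual_seed(0)
model = torch.nn.Linear(4, 1)
data = [(torch.randn(4), torch.randn(1)) for _ in range(8)]
optimizer = torch.optim.Adam(model.parameters(), lr=1e-2)

for epoch in range(5):
    optimizer.zero_grad()        # clear accumulated gradients once per epoch
    for x, y in data:            # process graphs/nodes one at a time
        loss = torch.nn.functional.mse_loss(model(x), y)
        loss.backward()          # gradients are ADDED into the .grad buffers
    optimizer.step()             # single weight update after all samples
```

Because only one sample's computation graph is alive at a time, peak GPU memory stays roughly constant regardless of dataset size.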

Oceanusity commented 2 years ago

This should therefore behave the same as full-batch training, with all the examples treated as one batch when updating the model weights.
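As a quick sanity check of this equivalence (again with a toy model, not DIG's code): accumulating per-sample gradients yields the same .grad as one backward pass over the summed full-batch loss, since gradients are additive.

```python
import torch

torch.manual_seed(0)
model = torch.nn.Linear(4, 1)
xs, ys = torch.randn(8, 4), torch.randn(8, 1)

# (a) per-sample backward passes; gradients accumulate in .grad
model.zero_grad()
for x, y in zip(xs, ys):
    torch.nn.functional.mse_loss(model(x), y, reduction='sum').backward()
grad_accumulated = model.weight.grad.clone()

# (b) one backward pass over the summed full-batch loss
model.zero_grad()
torch.nn.functional.mse_loss(model(xs), ys, reduction='sum').backward()
grad_full_batch = model.weight.grad.clone()

print(torch.allclose(grad_accumulated, grad_full_batch))  # True
```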

Welcome to reply if you have further questions.