FenTechSolutions / CausalDiscoveryToolbox

Package for causal inference in graphs and in the pairwise settings. Tools for graph structure recovery and dependencies are included.
https://fentechsolutions.github.io/CausalDiscoveryToolbox/html/index.html
MIT License

Potential Bug for GNN when computing the loss function #62

Closed: sAviOr287 closed this 4 years ago

sAviOr287 commented 4 years ago

Hi,

I just read through the CGNN code, mainly interested in the pairwise version.

It looks like the criterion is computing MMD(y, y_pred): https://github.com/FenTechSolutions/CausalDiscoveryToolbox/blob/32200779ab9b63762be3a24a2147cff09ba2bb72/cdt/causality/pairwise/GNN.py#L111

However, in the original paper they compute MMD([x, y], [x, y_pred]): https://github.com/GoudetOlivier/CGNN/blob/e3fcfc570e30fb8dad8bf00f619ef3c21998bb90/Code/cgnn/GNN.py#L70
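To make the difference concrete, here is a minimal sketch in plain Python (standard library only). It is not the CDT or CGNN implementation: the single-bandwidth Gaussian kernel, the bandwidth of 0.5, the helper names gaussian_kernel and mmd2, and the toy data are all assumptions chosen for illustration. It shows that a prediction with the right marginal distribution for y but the wrong relationship to x gets a near-zero MMD(y, y_pred), while MMD([x, y], [x, y_pred]) stays clearly positive.

```python
import math


def gaussian_kernel(a, b, bandwidth=0.5):
    """Gaussian (RBF) kernel between two points given as tuples (assumed kernel)."""
    sq_dist = sum((ai - bi) ** 2 for ai, bi in zip(a, b))
    return math.exp(-sq_dist / (2.0 * bandwidth ** 2))


def mmd2(sample_p, sample_q, bandwidth=0.5):
    """Biased estimate of the squared MMD between two lists of tuples."""

    def mean_kernel(u, v):
        total = sum(gaussian_kernel(a, b, bandwidth) for a in u for b in v)
        return total / (len(u) * len(v))

    return (mean_kernel(sample_p, sample_p)
            + mean_kernel(sample_q, sample_q)
            - 2.0 * mean_kernel(sample_p, sample_q))


# Toy data: y depends perfectly on x (y = x).
n = 50
x = [-1.0 + 2.0 * i / (n - 1) for i in range(n)]
y = list(x)

# A "prediction" with exactly the same values as y, but paired with the
# wrong x: the marginal of y is matched, the dependence on x is not.
y_pred = list(reversed(y))

# Criterion on y alone: cannot see the broken dependence.
marginal = mmd2([(v,) for v in y], [(v,) for v in y_pred])

# Criterion on the joint samples [x, y] vs [x, y_pred]: penalizes it.
joint = mmd2(list(zip(x, y)), list(zip(x, y_pred)))

print("MMD^2(y, y_pred)           =", marginal)  # approximately zero
print("MMD^2([x, y], [x, y_pred]) =", joint)     # clearly positive
```

In this toy case the marginal criterion gives no training signal about how y relates to x, which is the concern raised above.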

Thanks a lot for the repo and the reply. Helped me understand a lot of new things.

diviyank commented 4 years ago

Dear @sAviOr287 ,

Thanks for pointing that out: we will fix this in the next update!

Best regards, Diviyan

sAviOr287 commented 4 years ago

Thanks,

I was also wondering about the run times. The paper quoted 24 min on GPU, but when I tried out your code it seems to take 10+ hours with 6 GPUs for the Multi dataset, for example, as it has 300 datasets, each of which potentially has to be trained 32 x 2 times. I was wondering if you could clarify the number in the CGNN paper, or whether I am doing something wrong.

Thanks for the quick reply.

diviyank commented 4 years ago

I thought I fixed the performance issue; is your package up to date? It's the GNN, right (and not the CGNN)? Could you check via nvidia-smi that cdt uses the 6 GPUs? (One GNN is executed on one GPU; it's when you make multiple runs that all the GPUs are used.)

Depending on the memory consumption, you could squeeze multiple GNNs onto one GPU. How many examples do you have? (What is the size of your datasets?)

sAviOr287 commented 4 years ago

Yeah, I am using GNN.

I just cloned it 1 week ago. When I start the code it says I have 8 GPUs, but I only use 6 to run, and using nvidia-smi I can see that all 6 are being used. I am just running it on the Multi/Gauss/Net datasets, each of which has 300 datasets of 1500 points.

The settings are nb_max_runs=32, n_runs=6, train_epochs=1000, dataloader_workers=0. So is the 24 mins what I would get if I max out every single GPU, i.e. around 32 models at once?

diviyank commented 4 years ago

Ok great,

The 24 minutes were obtained for a single dataset (of 500 points), maxing out one GPU (Kepler architecture), for its 32x2 runs.
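As a rough back-of-the-envelope check, the figures quoted in this thread can be combined as follows. This is only a sketch under strong assumptions that are not confirmed anywhere above: that each of the 300 datasets costs the same 24 minutes as the 500-point dataset from the paper (the datasets here have 1500 points and run on different GPUs), and that the work scales perfectly linearly across 6 GPUs.

```python
# Figures taken from the thread.
minutes_per_dataset = 24   # one dataset, 32x2 runs, one GPU maxed out
n_datasets = 300           # datasets per benchmark (e.g. Multi)
n_gpus = 6                 # GPUs actually in use

# Assumption: every dataset costs the same 24 minutes on one GPU.
single_gpu_hours = minutes_per_dataset * n_datasets / 60

# Assumption: perfect linear speed-up across the GPUs.
ideal_multi_gpu_hours = single_gpu_hours / n_gpus

print("One GPU, all datasets:", single_gpu_hours, "hours")
print("Six GPUs, ideal scaling:", ideal_multi_gpu_hours, "hours")
```

Under these assumptions the whole benchmark takes about 20 hours on 6 GPUs, so a run time of 10+ hours is the expected order of magnitude and not by itself a sign of a misconfiguration.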

I noticed that there is no specific code for dataset prediction, and the memory consumption might be suboptimal, but this is a minor issue. (Keep in mind to run the pairs one by one)

However, something is strange here: 7728c3d should have removed this nb_max_runs argument; it should be deprecated... This comes from an old version of the code. Could you execute the following?

import cdt
print(cdt.__version__)

Otherwise, you could get some more performance by monitoring the GPU compute consumption (the % on the right in nvidia-smi) and the memory consumption. Setting n_jobs to 6*k and n_runs=32 is a nice trick to cut down computation time.

Best,

sAviOr287 commented 4 years ago

Oh yeah, you were right; somehow I didn't get the latest version. Could you tell me what you changed besides removing the testing bit? Training 1000 epochs is a lot faster somehow now.

I am also curious how people still got results with the wrong loss function lol

Thanks a lot for your help btw.

diviyank commented 4 years ago

The dataset management is done differently: the PyTorch Dataset feature might be a bit tricky to use, so I simplified it a lot and used the provided TensorDataset class, which eases the use of this feature (thus cutting down computation time in most cases).

There were maybe some (unconfirmed) performance issues on GNN (#44); could you check if the performance is still good (values around the ones from the paper)? I was afraid that I broke something during the refactoring.

diviyank commented 4 years ago

I'll be closing this issue; don't hesitate to open a new one if a performance issue on GNN arises. Fingers crossed that it was the bug that you noticed!