FluxML / GeometricFlux.jl

Geometric Deep Learning for Flux
https://fluxml.ai/GeometricFlux.jl/stable/
MIT License

gradient for GATConv layer is very slow #243

Closed: afternone closed this issue 2 years ago

afternone commented 2 years ago
@show gradient(X -> loss(X, train_y), train_X) # line 45 in 'example/gat.jl'

This line takes more than 200 s on the CPU, while the forward pass takes less than 1 s on the same CPU.

yuehhua commented 2 years ago

Computing the gradient over the whole graph can take much longer. It is suggested to compute gradients for fewer than 64 nodes per update; 32 nodes is a reasonable choice. I will release a new node-sampling feature that acts like the batch size in classical deep learning models: each update computes gradients for a randomly chosen subset of the nodes in the graph.
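A minimal sketch of the batching idea, using only Julia's standard `Random` library. It shuffles the node indices and splits them into batches of at most 32, then restricts a loss to one batch. The helpers `node_batches` and `masked_mse` are hypothetical illustrations, not part of the GeometricFlux API, and the released sampling feature may work differently (for example by sampling subgraphs). Restricting only the loss, as here, does not by itself guarantee a cheaper backward pass through a GAT layer.

```julia
using Random

# Hypothetical helper (not GeometricFlux API): split node indices 1:num_nodes
# into randomly shuffled mini-batches of at most `batch_size` nodes each.
function node_batches(num_nodes::Integer, batch_size::Integer; rng = Random.default_rng())
    perm = randperm(rng, num_nodes)  # random permutation of node indices
    return [perm[i:min(i + batch_size - 1, num_nodes)] for i in 1:batch_size:num_nodes]
end

# Hypothetical loss restricted to the sampled nodes: mean squared error over
# the columns (one column per node) of `yhat` and `y` selected by `idx`.
masked_mse(yhat, y, idx) = sum(abs2, yhat[:, idx] .- y[:, idx]) / length(idx)

# Example: 100 nodes in batches of 32; the last batch holds the remaining 4 nodes.
batches = node_batches(100, 32; rng = MersenneTwister(0))
println(length.(batches))  # [32, 32, 32, 4]
```

In a training loop, one would draw a fresh set of batches each epoch and take one gradient step per batch, as with mini-batches in ordinary neural network training.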

yuehhua commented 2 years ago

Check out the new version here. If you have any questions, feel free to reopen the issue.