google-deepmind / spectral_inference_networks

Implementation of Spectral Inference Networks, ICLR 2019
Apache License 2.0

Jax implementation and stability of training #3

Open Binbose opened 2 years ago

Binbose commented 2 years ago

Hey, we implemented the algorithm in JAX here: https://github.com/Binbose/SpIN-Jax

We also made a colab notebook with some extra visualizations, like animated training and phase diagrams, here: https://colab.research.google.com/drive/1hRm3zbf8ptJ00dGKKTohtBL3WNIg7tEl?usp=sharing

Generally, the algorithm is really cool and runs well: we can reliably recover the first four eigenfunctions of hydrogen. However, the variance of the energies in our implementation (and also in your TensorFlow implementation here, which behaves virtually identically to ours) is much larger than in the corresponding graph in your paper. (The variance in the colab notebook is divided by 15 to keep the graphs readable.) We are not sure how you achieved such stable training; do you remember whether you used any additional tricks? The same holds for the 9-eigenfunction case, where training is generally much less stable than what the paper shows, and the eigenfunctions, while somewhat recognizable, are still far off.
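For concreteness, this is roughly how we compute the per-batch energy estimates whose variance we are plotting, following our reading of the paper's Lambda = Sigma^{-1} Pi construction (the function names are ours, not the repo's API):

```python
import jax.numpy as jnp

def batch_energies(u, hu):
    """Per-batch eigenvalue (energy) estimates.

    u:  (N, K) network outputs u(x) on a batch of N points.
    hu: (N, K) the Hamiltonian applied to those outputs.
    """
    n = u.shape[0]
    sigma = u.T @ u / n                 # Sigma = E[u u^T]
    pi = u.T @ hu / n                   # Pi = E[u (H u)^T]
    lam = jnp.linalg.solve(sigma, pi)   # Lambda = Sigma^{-1} Pi
    return jnp.diag(lam)                # Rayleigh-quotient estimate per eigenfunction
```

Since Sigma and Pi are Monte Carlo estimates from a finite batch, some variance in these diagonals is expected; the question is why ours is so much larger than the paper's.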

Things we tried:

- reducing the learning rate down to 1e-6
- increasing the batch size to 512
- increasing and decreasing beta
- playing around with different 'sparsifying' K (we couldn't find the value you used in your work)
- trying different decay rates for RMSProp and Adam (see the sketch below)
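For reference, a minimal sketch of how we swept the optimiser settings (assuming optax; the specific values mirror the list above, not anything from the paper):

```python
import optax

def make_optimizer(lr=1e-6, decay=0.999, use_adam=False):
    """Build the optimiser variants we swept over."""
    if use_adam:
        # b2 is Adam's second-moment decay rate.
        return optax.adam(learning_rate=lr, b2=decay)
    # `decay` is RMSProp's moving-average decay rate.
    return optax.rmsprop(learning_rate=lr, decay=decay)
```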

Can you offer some additional insights?

dpfau commented 2 years ago

Hi Luca,

This is great - thanks for your interest! Unfortunately, it's been about 4 years since I really touched this, so I don't recall the exact parameters we used. I do remember we had to tweak the parameters very carefully to get it to work, and would have to train for up to a million iterations to really converge. It's also possible we substantially smoothed the variance in the figures to make them more readable.
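If smoothing accounts for part of the gap, a simple exponential moving average over the logged energies would reproduce that kind of figure. A minimal sketch (this is an assumption about how the plots could have been made, not a record of what we did):

```python
import numpy as np

def ema(values, alpha=0.01):
    """Exponentially smooth a logged curve; smaller alpha = heavier smoothing."""
    out = np.empty(len(values))
    acc = values[0]
    for i, v in enumerate(values):
        acc = (1.0 - alpha) * acc + alpha * v
        out[i] = acc
    return out
```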

If you are interested in pursuing this further, I'd reach out to Paris Perdikaris at Penn. His group has also re-implemented SpIN in JAX and can probably give you more up-to-date advice on how to get it to train effectively. The colab notebook is here: https://colab.sandbox.google.com/drive/1x4AAYfLro5ckjZvXrvccBcwIDf55ev4r?usp=sharing

Best of luck!
David
