google-deepmind / ferminet

An implementation of the Fermionic Neural Network for ab-initio electronic structure calculations
Apache License 2.0

how to reconstruct a ferminet? #6

Closed ley61 closed 3 years ago

ley61 commented 3 years ago

Dear authors, sorry to ask a naive question again.....

I am quite interested in the network and am trying to use it to study some interesting physics phenomena. However, I am new to TensorFlow, and I have run into some problems with how to reconstruct this FermiNet.

What I mean is: after I run the H2 example from the Usage section, the program saves the trained parameters in checkpoints. But I don't know how to use the files in the checkpoints to reconstruct the FermiNet I have trained, so that I can input the coordinates of the electrons and get back the corresponding wave-function value.

I realise this might be a somewhat annoying question; sorry to bother you.

dpfau commented 3 years ago

Not annoying at all, we're happy to help! This is one of the things that's a little trickier to do in TensorFlow, but more straightforward in JAX. You might want to try out the JAX branch of the code, especially if you're just looking at small systems like H2. Then, the parameters are just stored as a dictionary of NumPy arrays. If you want to train bigger systems, you may still want to use TensorFlow, because that has the KFAC optimizer available. But hopefully this is enough to get you started.
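As a rough sketch of what that looks like (the file name and the 'params' key below are assumptions -- check the checkpoint-saving code on the JAX branch for the exact names used by your version), you can inspect such a checkpoint directly with NumPy:

import numpy as np

# Load a saved JAX-branch checkpoint; the archive holds (nested dicts of) NumPy arrays.
with np.load('qmcjax_ckpt_001000.npz', allow_pickle=True) as ckpt:
  params = ckpt['params'].item()  # network parameters as a nested dictionary

# Walk the parameter tree and print the shape of every array.
def print_shapes(tree, prefix=''):
  if isinstance(tree, dict):
    for key, value in tree.items():
      print_shapes(value, prefix + key + '/')
  elif isinstance(tree, (list, tuple)):
    for i, value in enumerate(tree):
      print_shapes(value, prefix + str(i) + '/')
  else:
    print(prefix, np.asarray(tree).shape)

print_shapes(params)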

ley61 commented 3 years ago

Thanks, I will try the JAX branch and study more about TF~~

jsspencer commented 3 years ago

+100 to playing with the JAX version. Unfortunately we don't yet have a KFAC implementation for JAX, so it's limited, but it's much easier to inspect the checkpoints and reason about them.

For the TensorFlow 1 version, it's worth carefully reading the code and also following the tutorials and documentation on the TensorFlow website, in particular on checkpointing and MonitoredTrainingSession. One option would be to fork train.py and qmc.py (which already have support for loading checkpoints) and modify the "training" loop to take in the electron coordinates of your choosing. But it is also useful to play with this interactively.

Some ways of loading checkpoints:

(Please note that in all these examples I have not checked the batch size, number of iterations, MCMC steps, MCMC step size, learning rate, etc. for convergence/optimal values. I just picked values that ran quickly for the purpose of creating an example.)

  1. Loading checkpoints from a previous calculation.

First, run your calculation.

ferminet --system H2 --batch_size 256 --pretrain_iterations 100 --iterations 1000 

ferminet will create a time-stamped folder (by default under the working directory, which can be changed using the --result_folder flag), e.g. in my case ferminet_results_Wed_Dec__9_08\:37\:23_2020. This directory contains a checkpoints directory. Pass this to the --restore_path flag to restore the latest checkpoint stored in this directory. You want to disable pretraining and (typically) the MCMC burn-in here!

ferminet --system H2 --batch_size 256 --pretrain_iterations 0 --mcmc_burn_in 0 --iterations 1000 --result_folder H2_inference --restore_path ferminet_results_Wed_Dec__9_08\:37\:23_2020/checkpoints/

You should adjust the path passed to --restore_path to match your checkpoint directory.

Note setting --learning_rate 0 disables optimisation. This is useful for restoring a checkpoint and performing MCMC to evaluate the energy on a fixed network. We refer to this as inference in the FermiNet paper.
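For example, an inference-style run on the restored network might look like the following (a sketch only; it is the same set of flags as the command above plus --learning_rate 0, and I have not tuned any of the values):

ferminet --system H2 --batch_size 256 --pretrain_iterations 0 --mcmc_burn_in 0 --learning_rate 0 --iterations 1000 --result_folder H2_inference --restore_path ferminet_results_Wed_Dec__9_08\:37\:23_2020/checkpoints/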

  2. Restoring a checkpoint interactively. Note that it's vital to use the same network settings when restoring the checkpoint, and to construct the network in the same scope as it is created in train.train:
import numpy as np
import tensorflow.compat.v1 as tf
from ferminet import networks
from ferminet import train
from ferminet.utils import system

# Ensure you use the same system and geometry as in the original calculation!!
molecule = [system.Atom('H', (0, 0, -1)), system.Atom('H', (0, 0, 1))]
spins = (1, 1)

# Important! Must use the same network settings as used in the original calculation. Adjust these to match the original. If you didn't change anything, NetworkConfig() will match the original calculation.
network_config = train.NetworkConfig()

# Build the network in the same scope as the original. See train.train.
# (This is a little more complicated than necessary so we can also checkpoint the MCMC state.)
with tf.variable_scope('model') as model:
  pass
with tf.variable_scope(model, auxiliary_name_scope=False) as model1:
  with tf.name_scope(model1.original_name_scope):
    fermi_net = networks.FermiNet(
        atoms=molecule,
        nelectrons=spins,
        slater_dets=network_config.determinants,
        hidden_units=network_config.hidden_units,
        after_det=network_config.after_det,
        architecture=network_config.architecture,
        r12_ee_features=network_config.r12_ee_features,
        r12_en_features=network_config.r12_en_features,
        pos_ee_features=network_config.pos_ee_features,
        build_backflow=network_config.build_backflow,
        use_backflow=network_config.backflow,
        jastrow_en=network_config.jastrow_en,
        jastrow_ee=network_config.jastrow_ee,
        jastrow_een=network_config.jastrow_een,
        logdet=True,
        envelope=network_config.use_envelope,
        residual=network_config.residual,
        pretrain_iterations=0)

# Create your input data. It should be a 2D array of shape (B, 3*N), where N is the number of
# electrons and B is an arbitrary number of points. You can create this directly in TensorFlow,
# or in NumPy and then use tf.constant, or using placeholders (see the TF documentation).
xs = np.random.uniform(size=(3, 6))  # 3 sample points; 6 = 3 coordinates x 2 electrons for H2.
x = tf.constant(xs, dtype=tf.float32)  # CARE: FermiNet works in single precision by default.

# Feed the points into the network.
psi = fermi_net(x)

# Create a session which uses the checkpoint created previously. This is most easily done using MonitoredTrainingSession.
with tf.train.MonitoredTrainingSession(
    checkpoint_dir='H2/ferminet_results_Wed_Dec__9_08:37:23_2020/checkpoints',
    save_checkpoint_steps=None,  # we haven't created a global step in this example, so disable automatic checkpointing every Y steps
    save_summaries_steps=None,
    log_step_count_steps=None) as session:
  psi_1 = session.run(psi)

If this is successful, you should see something like the following (you might need to adjust the log-level settings -- see the TF docs):

INFO:tensorflow:Graph was finalized.
INFO:tensorflow:Restoring parameters from H2/ferminet_results_Wed_Dec__9_08:37:23_2020/checkpoints/model.ckpt-207
INFO:tensorflow:Running local_init_op.
INFO:tensorflow:Done running local_init_op.

If the checkpoint is not restored, only the first line and the last two lines will be printed.

Note that the above code becomes a fair amount more complicated if you are using multiple GPUs. Please refer to the TensorFlow DistributionStrategy documentation in this case -- the code in train.py and qmc.py for model construction and the training loop should be a good example of the steps required to set this up.
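A variant of the script above (a sketch, assuming the network accepts a placeholder input just as it accepts a constant): build the graph with a placeholder instead of the tf.constant, so that different electron coordinates can be fed through the same restored network via feed_dict. A fixed batch dimension is used below to be safe, in case some ops expect a static batch size.

x_in = tf.placeholder(tf.float32, shape=(10, 6))  # 10 points; 6 = 3 coordinates x 2 electrons for H2.
psi_in = fermi_net(x_in)  # call the network once, on the placeholder instead of the constant above

with tf.train.MonitoredTrainingSession(
    checkpoint_dir='H2/ferminet_results_Wed_Dec__9_08:37:23_2020/checkpoints',
    save_checkpoint_steps=None,
    save_summaries_steps=None,
    log_step_count_steps=None) as session:
  new_xs = np.random.uniform(size=(10, 6)).astype(np.float32)
  psi_new = session.run(psi_in, feed_dict={x_in: new_xs})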

  3. Another option is to simply return the network at the end of train.train. :)

ley61 commented 3 years ago

Thanks for your patience sincerely. I think I can enjoy more about this program now!

connection-on-fiber-bundles commented 3 years ago

Hey, thanks for the wonderful JAX implementation. Just wondering: is it the code directly used in the experiments in the paper Better, Faster Fermionic Neural Networks? If so, does that mean those experiments also used ADAM instead of KFAC as the optimizer (the paper didn't explicitly mention which optimizer was used)? Also, is it challenging to implement KFAC in JAX, and why?

dpfau commented 3 years ago

The JAX version is most of the code from the paper you linked. The training setup in that paper used the same KFAC parameters as in the original FermiNet paper. The JAX implementation of KFAC is a separate package that has not yet been cleared for open sourcing. Implementing KFAC in JAX is challenging - certainly more challenging than, say, ADAM. You have to add hooks that allow you to get out the intermediate activations and backward errors from intermediate layers, and also different layers have to be registered in different ways. You can try implementing it yourself if you like, but it may be easier to wait for it to be open-sourced.
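To make "intermediate activations and backward errors" concrete, here is a toy sketch (an illustration of the idea only, not the actual KFAC implementation): add a zero-valued probe to each layer's pre-activation, so the layer inputs come straight out of the forward pass and the backward errors fall out as gradients with respect to the probes.

import jax
import jax.numpy as jnp

def forward(params, probes, x):
  # Zero-valued probes are added to each pre-activation; the gradient with
  # respect to a probe is exactly the backward error for that layer.
  z0 = x @ params['w0'] + params['b0'] + probes['z0']
  h = jnp.tanh(z0)
  z1 = h @ params['w1'] + params['b1'] + probes['z1']
  return z1, {'a0': x, 'a1': h}  # per-layer inputs (activations)

def loss(params, probes, x, y):
  out, _ = forward(params, probes, x)
  return 0.5 * jnp.mean(jnp.sum((out - y) ** 2, axis=-1))

k0, k1, kx, ky = jax.random.split(jax.random.PRNGKey(0), 4)
x = jax.random.normal(kx, (8, 4))
y = jax.random.normal(ky, (8, 2))
params = {'w0': jax.random.normal(k0, (4, 16)), 'b0': jnp.zeros(16),
          'w1': jax.random.normal(k1, (16, 2)), 'b1': jnp.zeros(2)}
probes = {'z0': jnp.zeros((8, 16)), 'z1': jnp.zeros((8, 2))}

_, activations = forward(params, probes, x)               # a_i for each layer
errors = jax.grad(loss, argnums=1)(params, probes, x, y)  # g_i for each layer
# activations['a0'] with errors['z0'] (and 'a1' with 'z1') are the per-layer
# statistics a KFAC block accumulates to build its Kronecker factors.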


connection-on-fiber-bundles commented 3 years ago

@dpfau Cool got it. Thanks a lot!

connection-on-fiber-bundles commented 3 years ago


Hey @dpfau, any update on open-sourcing the JAX version of KFAC? I have seen Roger Grosse posting a JAX implementation of KFAC in his course material https://github.com/rgrosse/csc2541_examples for pedagogical purposes, but I'm not sure how usable or performant it is.

If the JAX KFAC is not going to be open-sourced any time soon, I will probably try to implement it myself. As you mentioned, I need to "add hooks that allow you to get out the intermediate activations and backward errors from intermediate layers". Just to double check: does JAX provide a hook mechanism like PyTorch does (register_forward_pre_hook and register_backward_hook in torch.nn.Module)? It seems Roger Grosse did something quite tricky in his implementation to get that information (he did explain the trick in his lecture notes, though).

dpfau commented 3 years ago

We are working on getting it out - I can't say for sure whether it will happen in a week or in a month. I am, however, pretty confident that we can get it released before you finish your own implementation from scratch. I doubt that Roger's implementation will reach the same accuracy we achieved.


connection-on-fiber-bundles commented 3 years ago

Got it. Thanks a lot for the detailed explanation, and I'm looking forward to the KFAC optimizer!

jacobjinkelly commented 3 years ago

Hello @dpfau! I'm just following up to see if you knew of any updates about when the KFAC optimizer in JAX might be released. Thanks!

dpfau commented 3 years ago

Hi Jacob,

It was released several months ago: https://github.com/deepmind/deepmind-research/tree/master/kfac_ferminet_alpha
The JAX branch is up-to-date and integrated with KFAC: https://github.com/deepmind/ferminet/tree/jax

Best, David


zhouyz08 commented 2 years ago

I'm trying to reconstruct the network too, and this post helped me.

But note that the way of setting the flags has changed; look into base_config for the correct way of referring to a flag.
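As a hedged sketch of what that looks like on the newer (JAX) version (the exact field names below are best guesses -- check ferminet/base_config.py in your version for the real ones):

from ferminet import base_config
from ferminet.utils import system

cfg = base_config.default()
cfg.system.electrons = (1, 1)  # H2: one spin-up and one spin-down electron
cfg.system.molecule = [system.Atom('H', (0, 0, -1)), system.Atom('H', (0, 0, 1))]
cfg.batch_size = 256
cfg.pretrain.iterations = 0    # skip pretraining when restoring
cfg.mcmc.burn_in = 0           # skip burn-in when restoring
cfg.log.restore_path = 'path/to/previous/checkpoints'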
