google-deepmind / ferminet

An implementation of the Fermionic Neural Network for ab-initio electronic structure calculations
Apache License 2.0

How to reproduce the results for Neon. JAX. #34

Closed: ghost closed this issue 3 years ago

ghost commented 3 years ago

Hi, I am trying to reproduce the results for Neon. I am running the following code with the default base config, changing only batch_size = 256, pretrain iterations = 100 and optim iterations = 100_000 (for now; I will increase these if the results don't match):

Training Code
```python
import sys

from absl import logging
from ferminet.utils import system
from ferminet import base_config
from ferminet import train

logging.get_absl_handler().python_handler.stream = sys.stdout
logging.set_verbosity(logging.INFO)

cfg = base_config.default()
cfg.system.electrons = (5, 5)
cfg.system.molecule = [system.Atom('Ne')]

cfg.batch_size = 256
cfg.pretrain.iterations = 100
cfg.optim.iterations = 100_000

train.train(cfg)
```
Loading Model
```python
from functools import partial

import jax
import numpy as np
from ferminet import networks
from ferminet import train

ckpt_path = 'ferminet_2021_08_22_16:24:03/qmcjax_ckpt_099929.npz'
with open(ckpt_path, 'rb') as f:
    ckpt = dict(np.load(f, allow_pickle=True))
params = ckpt['params'].tolist()
data = ckpt['data']

# `path` must be set to the directory prefix containing geometry.npz.
with open(path + 'geometry.npz', 'rb') as f:
    geometry = dict(np.load(f, allow_pickle=True))

foo = partial(networks.fermi_net, envelope_type='isotropic', full_det=False, **geometry)
# networks.fermi_net returns the sign and log of the wavefunction. We only need the latter.
network = lambda p, x: foo(p, x)[1]
batch_network = jax.vmap(network, (None, 0), 0)
loss = train.make_loss(network, batch_network, geometry['atoms'], geometry['charges'],
                       clip_local_energy=5.0)
# Right now, the code only works if the loss is wrapped in pmap.
ploss = jax.pmap(loss, axis_name='qmc_pmap_axis')

loss_ = ploss(params, data)  # For neon, should give -128.94165
loss_[0]
```

At the line `loss_ = ploss(params, data)` I get the following error:

ValueError: Incompatible shapes for broadcasting: ((1, 5, 1, 1), (3, 3, 1, 160))

I compared my params with the checkpoint files provided in the cloud, and my pi and sigma envelope parameters appear to have different shapes.
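One way to make that comparison systematic is to walk both parameter trees and list every leaf whose shape differs. This is a minimal sketch assuming the params are nested dicts/lists/tuples of arrays (as produced by `.tolist()` on the checkpoint's object array); the helper names `leaf_shapes` and `shape_mismatches` are hypothetical and not part of ferminet.

```python
import numpy as np


def leaf_shapes(tree, prefix=""):
    """Flatten a nested dict/list/tuple of arrays into {path: shape}."""
    if isinstance(tree, dict):
        out = {}
        for key, value in tree.items():
            out.update(leaf_shapes(value, f"{prefix}/{key}"))
        return out
    if isinstance(tree, (list, tuple)):
        out = {}
        for i, item in enumerate(tree):
            out.update(leaf_shapes(item, f"{prefix}/{i}"))
        return out
    # Leaf: any array-like (NumPy or JAX array, or scalar).
    return {prefix or "/": np.shape(tree)}


def shape_mismatches(params_a, params_b):
    """Return {path: (shape_a, shape_b)} for leaves that differ or exist in only one tree."""
    a, b = leaf_shapes(params_a), leaf_shapes(params_b)
    return {
        path: (a.get(path), b.get(path))
        for path in sorted(set(a) | set(b))
        if a.get(path) != b.get(path)
    }


# Illustrative (hypothetical) shapes, not ferminet's actual layout.
mine = {"envelope": [{"pi": np.zeros((1, 80)), "sigma": np.zeros((1, 80))}]}
ref = {"envelope": [{"pi": np.zeros((1, 80)), "sigma": np.zeros((3, 3, 1, 80))}]}
print(shape_mismatches(mine, ref))
```

Running it on two real param trees would point directly at the envelope leaves that a mismatched `envelope_type` changes.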

Any help on how to reproduce the pretrained results would be appreciated.

jsspencer commented 3 years ago

You need to use the same settings in networks.fermi_net as you used in training. In particular, the envelope_type and full_det arguments should match cfg.network.envelope_type and cfg.network.full_det.

ghost commented 3 years ago

Are the envelope_type and full_det values in networks.fermi_net used for the current Neon result the same as the current settings of cfg.network.envelope_type and cfg.network.full_det? If not, what should these values be?

Also, what parameter settings reproduce the results for the other elements? A list of the settings used to reproduce the results would be helpful :)

jsspencer commented 3 years ago

By "current Neon result" do you mean the result from the Phys Rev Research paper? That was produced with the TF version (which is essentially deprecated). The isotropic envelope setting was introduced in the NeurIPS workshop paper, which motivates it and compares it to the full envelope. All results published before then used the full envelope setting.
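To see why the two settings produce incompatible checkpoints, here is a minimal NumPy sketch of the two envelope forms, assuming the FermiNet-paper definitions: the full envelope uses a 3x3 matrix per atom and orbital, the isotropic one a single scalar decay rate. The function names and the `(natom, norb, ...)` axis order are hypothetical and do not follow ferminet's actual parameter layout.

```python
import numpy as np


def isotropic_envelope(r, atoms, pi, sigma):
    """sum_m pi[m, i] * exp(-|sigma[m, i]| * |r - R_m|).

    r: (3,) electron position; atoms: (natom, 3);
    pi, sigma: (natom, norb), one scalar decay rate per atom and orbital.
    """
    dist = np.linalg.norm(r[None, :] - atoms, axis=-1)                  # (natom,)
    return np.sum(pi * np.exp(-np.abs(sigma) * dist[:, None]), axis=0)  # (norb,)


def full_envelope(r, atoms, pi, sigma):
    """sum_m pi[m, i] * exp(-|Sigma[m, i] @ (r - R_m)|).

    sigma: (natom, norb, 3, 3), a 3x3 matrix per atom and orbital.
    """
    disp = r[None, :] - atoms                                           # (natom, 3)
    proj = np.einsum("moij,mj->moi", sigma, disp)                       # (natom, norb, 3)
    return np.sum(pi * np.exp(-np.linalg.norm(proj, axis=-1)), axis=0)  # (norb,)
```

Because sigma is `(natom, norb)` in one case and `(natom, norb, 3, 3)` in the other, parameters saved from one network cannot be evaluated by a network built with the other setting, which is consistent with the broadcasting error reported above. The full form reduces to the isotropic one when each Sigma is a multiple of the identity.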

The full_det setting is experimental. Set it to False to match the published results.

For neon, neither of these settings will make a substantial difference to the final energy within statistical errors. Note your batch size is very small and might limit the accuracy you achieve.
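Judging whether two energies agree "within statistical errors" needs an error bar, and successive MCMC energies are correlated, so the naive standard error of the mean is too small. A common remedy (not necessarily what the FermiNet authors used) is Flyvbjerg-Petersen reblocking. This sketch assumes a 1D array of per-iteration energies from an evaluation run; the function name `reblock_errors` and the `min_blocks` cutoff are hypothetical choices.

```python
import numpy as np


def reblock_errors(samples, min_blocks=32):
    """Flyvbjerg-Petersen reblocking of a correlated 1D series.

    Returns a list of (block_size, standard_error) pairs. For correlated data
    the estimate grows with block size until blocks are roughly independent,
    then plateaus; the plateau value is the error bar to report.
    """
    x = np.asarray(samples, dtype=float)
    results = []
    block_size = 1
    while len(x) >= min_blocks:
        se = np.std(x, ddof=1) / np.sqrt(len(x))
        results.append((block_size, float(se)))
        if len(x) % 2:                    # drop a trailing sample so pairs line up
            x = x[:-1]
        x = 0.5 * (x[0::2] + x[1::2])     # average neighbouring pairs
        block_size *= 2
    return results
```

If the estimate never plateaus before fewer than `min_blocks` blocks remain, the run is too short for a reliable error bar.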

Please refer to our papers for the settings used in the models. The PRR paper used the same set of configuration options for all experiments, except where noted, and this broadly matches the current defaults (except for full_det and envelope_type).
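A minimal sketch of a Neon config matching the published settings described above, assuming the 2021-era ferminet config field names used in this thread and assuming `'full'` is the option name for the full envelope (check `base_config.py` in your version).

```python
from ferminet import base_config
from ferminet.utils import system

cfg = base_config.default()
cfg.system.electrons = (5, 5)
cfg.system.molecule = [system.Atom('Ne')]

# Settings matching the published (PRR) results.
cfg.network.envelope_type = 'full'  # assumed option name for the full envelope
cfg.network.full_det = False

# Leave cfg.batch_size at its default instead of reducing it to 256.
```

The same two values then need to be passed when rebuilding the network for evaluation, e.g. `partial(networks.fermi_net, envelope_type=cfg.network.envelope_type, full_det=cfg.network.full_det, **geometry)`.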