google-deepmind / ferminet

An implementation of the Fermionic Neural Network for ab-initio electronic structure calculations
Apache License 2.0

Upstream breaking change in `kfac-jax` #70

Closed by gcassella 11 months ago

gcassella commented 1 year ago

The most recent commit to the kfac-jax repo (f466559d86b07d6a2291cc699ac769c8e0931592 at the time of writing) introduces a breaking change for ferminet. The last working kfac-jax commit is bacdf8eaf4f5bd1a467b7e9d9703e571ed37c897. As a result, following the installation and usage instructions in README.md produces a broken installation.
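A quick way to tell whether an installed kfac-jax is affected is to check for the attribute that ferminet looks up at import time. The helper below is a hypothetical sketch (the name `kfac_jax_is_compatible` is not part of either library); it only tests for `kfac_jax.utils.psd_inv_cholesky` and does not verify anything else about the installed version.

```python
import importlib


def kfac_jax_is_compatible() -> bool:
  """Hypothetical check: True if kfac_jax still exposes utils.psd_inv_cholesky.

  ferminet's curvature_tags_and_blocks.py reads this attribute when it is
  imported, so its absence leads to the AttributeError reported in this issue.
  """
  try:
    kfac_jax = importlib.import_module("kfac_jax")
  except ImportError:
    return False  # kfac-jax is not installed at all.
  utils = getattr(kfac_jax, "utils", None)
  return utils is not None and hasattr(utils, "psd_inv_cholesky")


if __name__ == "__main__":
  print("compatible" if kfac_jax_is_compatible() else "incompatible")
```

If the check reports `incompatible` on an otherwise working install, pinning kfac-jax to the last working commit named above is one way to restore the README setup.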

To reproduce, install following the usual instructions and run:

```python
import sys

from absl import logging
from ferminet.utils import system
from ferminet import base_config
from ferminet import train

# Settings in a config file are loaded by executing the get_config
# function.
def get_config():
  # Get default options.
  cfg = base_config.default()
  # Set up molecule
  cfg.system.electrons = (1,1)
  cfg.system.molecule = [system.Atom('H', (0, 0, -1)), system.Atom('H', (0, 0, 1))]

  # Set training hyperparameters
  cfg.batch_size = 256
  cfg.pretrain.iterations = 100

  return cfg

# Optional, for also printing training progress to STDOUT.
# If running a script, you can also just use the --alsologtostderr flag.
logging.get_absl_handler().python_handler.stream = sys.stdout
logging.set_verbosity(logging.INFO)

# Build the H2 config and start training.
cfg = get_config()
train.train(cfg)
```

This results in the following stack trace:

```
Traceback (most recent call last):
  File "/home/ettore/ferminet/test.py", line 6, in <module>
    from ferminet import train
  File "/home/ettore/ferminet/ferminet/train.py", line 24, in <module>
    from ferminet import checkpoint
  File "/home/ettore/ferminet/ferminet/checkpoint.py", line 24, in <module>
    from ferminet import networks
  File "/home/ettore/ferminet/ferminet/networks.py", line 21, in <module>
    from ferminet import envelopes
  File "/home/ettore/ferminet/ferminet/envelopes.py", line 21, in <module>
    from ferminet import curvature_tags_and_blocks
  File "/home/ettore/ferminet/ferminet/curvature_tags_and_blocks.py", line 27, in <module>
    vmap_psd_inv_cholesky = jax.vmap(kfac_jax.utils.psd_inv_cholesky, (0, None), 0)
AttributeError: module 'kfac_jax._src.utils' has no attribute 'psd_inv_cholesky'
```
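The failing line batches `psd_inv_cholesky` with `jax.vmap(..., (0, None), 0)`: the first argument is mapped over its leading axis (a stack of matrices), the second is shared across the batch, and the result is stacked along axis 0. Besides pinning kfac-jax, another possible local workaround is to define the helper inside ferminet. The sketch below is hypothetical and rests on an assumption: that the removed `kfac_jax.utils.psd_inv_cholesky(matrix, damping)` returned the inverse of `matrix + damping * I` for a symmetric positive semi-definite `matrix`, computed through a Cholesky factorization. It is not the upstream fix.

```python
import jax
import jax.numpy as jnp
import jax.scipy.linalg as jsp_linalg


def psd_inv_cholesky(matrix: jnp.ndarray, damping) -> jnp.ndarray:
  """Hypothetical stand-in for the removed kfac_jax.utils.psd_inv_cholesky.

  ASSUMPTION: returns (matrix + damping * I)^{-1}, where `matrix` is
  symmetric positive semi-definite and `damping` makes the sum positive
  definite so the Cholesky factorization exists.
  """
  identity = jnp.eye(matrix.shape[-1], dtype=matrix.dtype)
  damped = matrix + damping * identity
  # Factor damped = L L^T, then solve damped @ X = I for X = damped^{-1}.
  factor = jsp_linalg.cho_factor(damped, lower=True)
  return jsp_linalg.cho_solve(factor, identity)


# Same batching as curvature_tags_and_blocks.py: map over a stack of matrices
# (axis 0) while sharing one damping value (None means "not mapped").
vmap_psd_inv_cholesky = jax.vmap(psd_inv_cholesky, (0, None), 0)

# Example: a batch of two 3x3 matrices with zero damping.
batch = jnp.stack([2.0 * jnp.eye(3), 4.0 * jnp.eye(3)])
inverses = vmap_psd_inv_cholesky(batch, 0.0)  # shape (2, 3, 3)
```

If the assumed semantics match, the `vmap` line in `curvature_tags_and_blocks.py` could wrap this local function instead of `kfac_jax.utils.psd_inv_cholesky`; comparing its output against the old kfac-jax commit on a few matrices would confirm that before relying on it.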