google-deepmind / ferminet

An implementation of the Fermionic Neural Network for ab-initio electronic structure calculations

kfac_jax error when running H2 example script #69

Closed: eliasteve closed this issue 1 year ago

eliasteve commented 1 year ago

Hi, I'm trying to run the example script for the H2 molecule on Colab, but I hit an AttributeError on the first iteration of training. Below are the commands I'm using to install the relevant packages:

pip install git+https://github.com/deepmind/ferminet@main
pip install numpy==1.26.0  # pyscf fails to import with numpy 1.26.1

The script I'm trying to run:

import sys

from absl import logging
from ferminet.utils import system
from ferminet import base_config
from ferminet import train

# Optional, for also printing training progress to STDOUT.
# If running a script, you can also just use the --alsologtostderr flag.
logging.get_absl_handler().python_handler.stream = sys.stdout
logging.set_verbosity(logging.INFO)

# Define H2 molecule
cfg = base_config.default()
cfg.system.electrons = (1,1)  # (alpha electrons, beta electrons)
cfg.system.molecule = [system.Atom('H', (0, 0, -1)), system.Atom('H', (0, 0, 1))]

# Set training parameters
cfg.batch_size = 256
cfg.pretrain.iterations = 100

train.train(cfg)

And here is the full printout of the error message:

INFO:absl:Starting QMC with 1 XLA devices per host across 1 hosts.
converged SCF energy = -1.05642988216974  <S^2> = -4.4408921e-16  2S+1 = 1
INFO:absl:No checkpoint found. Training new model.
INFO:absl:Pretrain iter 00000: 0.0987188
INFO:absl:Pretrain iter 00001: 0.0468968
INFO:absl:Pretrain iter 00002: 0.021173
INFO:absl:Pretrain iter 00003: 0.012926
INFO:absl:Pretrain iter 00004: 0.012574
INFO:absl:Pretrain iter 00005: 0.0137347
INFO:absl:Pretrain iter 00006: 0.0137845
INFO:absl:Pretrain iter 00007: 0.012732
INFO:absl:Pretrain iter 00008: 0.0110899
INFO:absl:Pretrain iter 00009: 0.00934778
INFO:absl:Pretrain iter 00010: 0.00778915
INFO:absl:Pretrain iter 00011: 0.00656673
INFO:absl:Pretrain iter 00012: 0.00565565
INFO:absl:Pretrain iter 00013: 0.00497928
INFO:absl:Pretrain iter 00014: 0.00446487
INFO:absl:Pretrain iter 00015: 0.00406806
INFO:absl:Pretrain iter 00016: 0.00370217
INFO:absl:Pretrain iter 00017: 0.00335341
INFO:absl:Pretrain iter 00018: 0.00301634
INFO:absl:Pretrain iter 00019: 0.00270557
INFO:absl:Pretrain iter 00020: 0.00244722
INFO:absl:Pretrain iter 00021: 0.00228125
INFO:absl:Pretrain iter 00022: 0.00220615
INFO:absl:Pretrain iter 00023: 0.00218246
INFO:absl:Pretrain iter 00024: 0.00216123
INFO:absl:Pretrain iter 00025: 0.00209267
INFO:absl:Pretrain iter 00026: 0.00195859
INFO:absl:Pretrain iter 00027: 0.0017916
INFO:absl:Pretrain iter 00028: 0.00160594
INFO:absl:Pretrain iter 00029: 0.00143077
INFO:absl:Pretrain iter 00030: 0.0012787
INFO:absl:Pretrain iter 00031: 0.00115612
INFO:absl:Pretrain iter 00032: 0.00107149
INFO:absl:Pretrain iter 00033: 0.00101318
INFO:absl:Pretrain iter 00034: 0.00097588
INFO:absl:Pretrain iter 00035: 0.000957538
INFO:absl:Pretrain iter 00036: 0.000952348
INFO:absl:Pretrain iter 00037: 0.000960992
INFO:absl:Pretrain iter 00038: 0.000957644
INFO:absl:Pretrain iter 00039: 0.00094491
INFO:absl:Pretrain iter 00040: 0.000913838
INFO:absl:Pretrain iter 00041: 0.000867143
INFO:absl:Pretrain iter 00042: 0.000815364
INFO:absl:Pretrain iter 00043: 0.000763372
INFO:absl:Pretrain iter 00044: 0.000719524
INFO:absl:Pretrain iter 00045: 0.00068491
INFO:absl:Pretrain iter 00046: 0.000651115
INFO:absl:Pretrain iter 00047: 0.000626584
INFO:absl:Pretrain iter 00048: 0.000609355
INFO:absl:Pretrain iter 00049: 0.000602466
INFO:absl:Pretrain iter 00050: 0.000594917
INFO:absl:Pretrain iter 00051: 0.000595118
INFO:absl:Pretrain iter 00052: 0.000592292
INFO:absl:Pretrain iter 00053: 0.000585884
INFO:absl:Pretrain iter 00054: 0.000574762
INFO:absl:Pretrain iter 00055: 0.000557475
INFO:absl:Pretrain iter 00056: 0.000534383
INFO:absl:Pretrain iter 00057: 0.000514163
INFO:absl:Pretrain iter 00058: 0.000497474
INFO:absl:Pretrain iter 00059: 0.000483247
INFO:absl:Pretrain iter 00060: 0.000474224
INFO:absl:Pretrain iter 00061: 0.000468507
INFO:absl:Pretrain iter 00062: 0.000463785
INFO:absl:Pretrain iter 00063: 0.000460483
INFO:absl:Pretrain iter 00064: 0.000456877
INFO:absl:Pretrain iter 00065: 0.000450266
INFO:absl:Pretrain iter 00066: 0.000445674
INFO:absl:Pretrain iter 00067: 0.00043909
INFO:absl:Pretrain iter 00068: 0.00043164
INFO:absl:Pretrain iter 00069: 0.000425373
INFO:absl:Pretrain iter 00070: 0.000416835
INFO:absl:Pretrain iter 00071: 0.000408702
INFO:absl:Pretrain iter 00072: 0.000404959
INFO:absl:Pretrain iter 00073: 0.000398033
INFO:absl:Pretrain iter 00074: 0.000392218
INFO:absl:Pretrain iter 00075: 0.000388479
INFO:absl:Pretrain iter 00076: 0.000385467
INFO:absl:Pretrain iter 00077: 0.000381409
INFO:absl:Pretrain iter 00078: 0.000377661
INFO:absl:Pretrain iter 00079: 0.00037244
INFO:absl:Pretrain iter 00080: 0.000368109
INFO:absl:Pretrain iter 00081: 0.000364462
INFO:absl:Pretrain iter 00082: 0.000360276
INFO:absl:Pretrain iter 00083: 0.0003561
INFO:absl:Pretrain iter 00084: 0.000352395
INFO:absl:Pretrain iter 00085: 0.000348523
INFO:absl:Pretrain iter 00086: 0.000344737
INFO:absl:Pretrain iter 00087: 0.000342277
INFO:absl:Pretrain iter 00088: 0.000339568
INFO:absl:Pretrain iter 00089: 0.000336274
INFO:absl:Pretrain iter 00090: 0.000333011
INFO:absl:Pretrain iter 00091: 0.000329609
INFO:absl:Pretrain iter 00092: 0.000326805
INFO:absl:Pretrain iter 00093: 0.000322345
INFO:absl:Pretrain iter 00094: 0.000320427
INFO:absl:Pretrain iter 00095: 0.000316732
INFO:absl:Pretrain iter 00096: 0.000314699
INFO:absl:Pretrain iter 00097: 0.000311589
INFO:absl:Pretrain iter 00098: 0.000308642
INFO:absl:Pretrain iter 00099: 0.000305587
INFO:absl:==================================================
INFO:absl:Graph parameter registrations:
INFO:absl:{'envelope': [{'pi': 'Auto[scale_and_shift_tag_1]',
               'sigma': 'Auto[scale_and_shift_tag_0]'},
              {'pi': 'Auto[scale_and_shift_tag_3]',
               'sigma': 'Auto[scale_and_shift_tag_2]'}],
 'layers': {'input': {},
            'streams': [{'double': {'b': 'Auto[repeated_dense_tag_1]',
                                    'w': 'Auto[repeated_dense_tag_1]'},
                         'single': {'b': 'Auto[repeated_dense_tag_0]',
                                    'w': 'Auto[repeated_dense_tag_0]'}},
                        {'double': {'b': 'Auto[repeated_dense_tag_3]',
                                    'w': 'Auto[repeated_dense_tag_3]'},
                         'single': {'b': 'Auto[repeated_dense_tag_2]',
                                    'w': 'Auto[repeated_dense_tag_2]'}},
                        {'double': {'b': 'Auto[repeated_dense_tag_5]',
                                    'w': 'Auto[repeated_dense_tag_5]'},
                         'single': {'b': 'Auto[repeated_dense_tag_4]',
                                    'w': 'Auto[repeated_dense_tag_4]'}},
                        {'single': {'b': 'Auto[repeated_dense_tag_6]',
                                    'w': 'Auto[repeated_dense_tag_6]'}}]},
 'orbital': [{'w': 'Auto[repeated_dense_tag_7]'},
             {'w': 'Auto[repeated_dense_tag_8]'}]}
INFO:absl:==================================================
INFO:absl:Burning in MCMC chain for 100 steps
INFO:absl:Completed burn-in MCMC steps
INFO:absl:Initial energy: -1.1759 E_h
---------------------------------------------------------------------------
AttributeError                            Traceback (most recent call last)
<ipython-input-1-df9aebf4a1f2> in <cell line: 22>()
     20 cfg.pretrain.iterations = 100
     21 
---> 22 train.train(cfg)

16 frames
/usr/local/lib/python3.10/dist-packages/ferminet/train.py in train(cfg, writer_manager)
    712     for t in range(t_init, cfg.optim.iterations):
    713       sharded_key, subkeys = kfac_jax.utils.p_split(sharded_key)
--> 714       data, params, opt_state, loss, unused_aux_data, pmove = step(
    715           data,
    716           params,

/usr/local/lib/python3.10/dist-packages/ferminet/train.py in step(data, params, state, key, mcmc_width)
    313 
    314     # Optimization step
--> 315     new_params, state, stats = optimizer.step(
    316         params=params,
    317         state=state,

/usr/local/lib/python3.10/dist-packages/kfac_jax/_src/optimizer.py in step(self, params, state, rng, data_iterator, batch, func_state, learning_rate, momentum, damping, global_step_int)
   1214       batch = next(data_iterator)
   1215 
-> 1216     return self._step(params, state, rng, batch, func_state,
   1217                       learning_rate, momentum, damping)
   1218 

/usr/local/lib/python3.10/dist-packages/kfac_jax/_src/utils/staging.py in decorated(instance, *args)
    246           pmap_funcs[key] = func
    247 
--> 248         outs = func(instance, *args)
    249 
    250       else:

    [... skipping hidden 12 frame]

/usr/local/lib/python3.10/dist-packages/kfac_jax/_src/utils/misc.py in wrapped(instance, *args, **kwargs)
    283       method_name = method_name[1:]
    284     with jax.named_scope(f"{class_name}_{method_name}"):
--> 285       return method(instance, *args, **kwargs)
    286 
    287   return wrapped

/usr/local/lib/python3.10/dist-packages/kfac_jax/_src/optimizer.py in _step(self, params, state, rng, batch, func_state, learning_rate, momentum, damping)
   1020 
   1021     # Update curvature estimate
-> 1022     state = self._maybe_update_estimator_curvature(
   1023         state,
   1024         func_args,

/usr/local/lib/python3.10/dist-packages/kfac_jax/_src/optimizer.py in _maybe_update_estimator_curvature(self, state, func_args, rng, ema_old, ema_new, sync)
    723   ) -> "Optimizer.State":
    724     """Updates the curvature estimates if it is the right iteration."""
--> 725     return self._maybe_update_estimator_state(
    726         state,
    727         self.should_update_estimate_curvature(state),

/usr/local/lib/python3.10/dist-packages/kfac_jax/_src/optimizer.py in _maybe_update_estimator_state(self, state, should_update, update_func, **update_func_kwargs)
    678     state = state.copy()
    679 
--> 680     state.estimator_state = lax.cond(
    681         should_update,
    682         functools.partial(update_func, **update_func_kwargs),

    [... skipping hidden 13 frame]

/usr/local/lib/python3.10/dist-packages/kfac_jax/_src/optimizer.py in _update_estimator_curvature(self, estimator_state, func_args, rng, ema_old, ema_new, sync)
    696   ) -> curvature_estimator.BlockDiagonalCurvature.State:
    697     """Updates the curvature estimator state."""
--> 698     state = self.estimator.update_curvature_matrix_estimate(
    699         state=estimator_state,
    700         ema_old=ema_old,

/usr/local/lib/python3.10/dist-packages/kfac_jax/_src/utils/misc.py in wrapped(instance, *args, **kwargs)
    283       method_name = method_name[1:]
    284     with jax.named_scope(f"{class_name}_{method_name}"):
--> 285       return method(instance, *args, **kwargs)
    286 
    287   return wrapped

/usr/local/lib/python3.10/dist-packages/kfac_jax/_src/curvature_estimator.py in update_curvature_matrix_estimate(self, state, ema_old, ema_new, batch_size, rng, func_args, estimation_mode)
   1239 
   1240     # Compute the losses and the VJP function from the function inputs
-> 1241     losses, losses_vjp = self._compute_losses_vjp(func_args)
   1242 
   1243     if "fisher" in estimation_mode:

/usr/local/lib/python3.10/dist-packages/kfac_jax/_src/utils/misc.py in wrapped(instance, *args, **kwargs)
    283       method_name = method_name[1:]
    284     with jax.named_scope(f"{class_name}_{method_name}"):
--> 285       return method(instance, *args, **kwargs)
    286 
    287   return wrapped

/usr/local/lib/python3.10/dist-packages/kfac_jax/_src/curvature_estimator.py in _compute_losses_vjp(self, func_args)
   1037   def _compute_losses_vjp(self, func_args: utils.FuncArgs):
   1038     """Computes all model statistics needed for estimating the curvature."""
-> 1039     return self._vjp(func_args)
   1040 
   1041   def params_vector_to_blocks_vectors(

/usr/local/lib/python3.10/dist-packages/kfac_jax/_src/tracer.py in wrapped_transformation(func_args, return_only_jaxpr, *args)
    379       return jaxpr
    380     else:
--> 381       return f(func_args, *args)
    382 
    383   return wrapped_transformation

/usr/local/lib/python3.10/dist-packages/kfac_jax/_src/tracer.py in _layer_tag_vjp(processed_jaxpr, primal_func_args)
    781 
    782   # First compute the primal values for the inputs to all layer tags
--> 783   layer_input_values = forward()
    784   primals_dict = dict(zip(layer_input_vars, layer_input_values))
    785 

/usr/local/lib/python3.10/dist-packages/kfac_jax/_src/tracer.py in forward()
    698     for eqn in processed_jaxpr.jaxpr.eqns:
    699 
--> 700       write(eqn.outvars, tgm.eval_jaxpr_eqn(eqn, read(eqn.invars)))
    701 
    702       if isinstance(eqn.primitive, tags.LossTag):

/usr/local/lib/python3.10/dist-packages/kfac_jax/_src/tag_graph_matcher.py in eval_jaxpr_eqn(eqn, in_values)
     63 
     64   if jax_version > (0, 4, 11):
---> 65     user_context = jax_extend.source_info_util.user_context
     66   else:
     67     user_context = jax.core.source_info_util.user_context

AttributeError: module 'jax.extend' has no attribute 'source_info_util'
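
For context, the failing branch in kfac_jax's tag_graph_matcher assumes any jax newer than 0.4.11 exposes jax.extend.source_info_util, which jax 0.4.16 does not yet do. A defensive lookup along these lines (a sketch only, assuming the private jax._src.source_info_util module is importable in the affected release; this is not the library's actual fix) would sidestep the AttributeError:

import jax.extend as jax_extend

if hasattr(jax_extend, "source_info_util"):
    # Newer jax releases re-export source_info_util under jax.extend.
    user_context = jax_extend.source_info_util.user_context
else:
    # Releases such as 0.4.16 ship jax.extend without a source_info_util
    # attribute; fall back to the private location.
    from jax._src import source_info_util
    user_context = source_info_util.user_context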
jsspencer commented 1 year ago

I cannot reproduce this with jax 0.4.19. What version of jax are you using?

Also pyscf 2.4.0 works fine for me with numpy 1.26.1 -- are you also using an older version of pyscf?

Running pip freeze is helpful to show the versions of all installed modules.
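
For example, something like this narrows the output to the packages in play here:

pip freeze | grep -E "jax|kfac|numpy|pyscf|scipy"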

eliasteve commented 1 year ago

The version of jax was 0.4.16; after installing version 0.4.19 the code runs fine, thanks! As for pyscf, I'm running version 2.4.0 as well, but when I try importing it (with numpy version 1.26.1) I get

---------------------------------------------------------------------------
AttributeError                            Traceback (most recent call last)
<ipython-input-2-39797ef5fbb5> in <cell line: 1>()
----> 1 import pyscf

10 frames
/usr/local/lib/python3.10/dist-packages/pyscf/__init__.py in <module>
     97 from pyscf import lib
     98 from pyscf import gto
---> 99 from pyscf import scf
    100 from pyscf import ao2mo
    101 

/usr/local/lib/python3.10/dist-packages/pyscf/scf/__init__.py in <module>
    103 
    104 from pyscf import gto
--> 105 from pyscf.scf import hf
    106 rhf = hf
    107 from pyscf.scf import rohf

/usr/local/lib/python3.10/dist-packages/pyscf/scf/hf.py in <module>
     31 from pyscf import lib
     32 from pyscf.lib import logger
---> 33 from pyscf.scf import diis
     34 from pyscf.scf import _vhf
     35 from pyscf.scf import chkfile

/usr/local/lib/python3.10/dist-packages/pyscf/scf/diis.py in <module>
     24 import numpy
     25 import scipy.linalg
---> 26 import scipy.optimize
     27 from pyscf import lib
     28 from pyscf.lib import logger

/usr/local/lib/python3.10/dist-packages/scipy/optimize/__init__.py in <module>
    408 
    409 from ._optimize import *
--> 410 from ._minimize import *
    411 from ._root import *
    412 from ._root_scalar import *

/usr/local/lib/python3.10/dist-packages/scipy/optimize/_minimize.py in <module>
     25 from ._trustregion_krylov import _minimize_trust_krylov
     26 from ._trustregion_exact import _minimize_trustregion_exact
---> 27 from ._trustregion_constr import _minimize_trustregion_constr
     28 
     29 # constrained minimization

/usr/local/lib/python3.10/dist-packages/scipy/optimize/_trustregion_constr/__init__.py in <module>
      2 
      3 
----> 4 from .minimize_trustregion_constr import _minimize_trustregion_constr
      5 
      6 __all__ = ['_minimize_trustregion_constr']

/usr/local/lib/python3.10/dist-packages/scipy/optimize/_trustregion_constr/minimize_trustregion_constr.py in <module>
      3 from scipy.sparse.linalg import LinearOperator
      4 from .._differentiable_functions import VectorFunction
----> 5 from .._constraints import (
      6     NonlinearConstraint, LinearConstraint, PreparedConstraint, Bounds, strict_bounds)
      7 from .._hessian_update_strategy import BFGS

/usr/local/lib/python3.10/dist-packages/scipy/optimize/_constraints.py in <module>
      6 from ._optimize import OptimizeWarning
      7 from warnings import warn, catch_warnings, simplefilter
----> 8 from numpy.testing import suppress_warnings
      9 from scipy.sparse import issparse
     10 

/usr/local/lib/python3.10/dist-packages/numpy/testing/__init__.py in <module>
      9 
     10 from . import _private
---> 11 from ._private.utils import *
     12 from ._private.utils import (_assert_valid_refcount, _gen_alignment_data)
     13 from ._private import extbuild

/usr/local/lib/python3.10/dist-packages/numpy/testing/_private/utils.py in <module>
     55 IS_PYSTON = hasattr(sys, "pyston_version_info")
     56 HAS_REFCOUNT = getattr(sys, 'getrefcount', None) is not None and not IS_PYSTON
---> 57 HAS_LAPACK64 = numpy.linalg._umath_linalg._ilp64
     58 
     59 _OLD_PROMOTION = lambda: np._get_promotion_state() == 'legacy'

AttributeError: module 'numpy.linalg._umath_linalg' has no attribute '_ilp64'

However, this is not critical, since pinning numpy to version 1.26.0 resolves the issue for me.
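
For anyone landing here with the same pyscf import failure: the crash is numpy's own testing helper failing to find a private symbol (numpy.linalg._umath_linalg._ilp64) in its compiled linalg extension, which commonly points to a stale or mixed numpy install, for instance after pip-changing numpy in a live Colab session without restarting the runtime. A minimal sketch of the workaround used in this thread, plus the usual Colab caveat:

pip install numpy==1.26.0
# On Colab, restart the runtime after changing numpy so that the compiled
# extension modules are re-imported cleanly, then retry importing pyscf.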