google-deepmind / ferminet

An implementation of the Fermionic Neural Network for ab-initio electronic structure calculations
Apache License 2.0

AttributeError: module 'jax.core' has no attribute 'extract_call_jaxpr' #44

Closed: draxelsen closed this issue 2 years ago

draxelsen commented 2 years ago

I am running with jaxlib==0.3.0 on CUDA 11. The run starts on my two V100 GPUs but stops with the following error:

Traceback (most recent call last):
  File "/opt/conda/bin/ferminet", line 7, in <module>
    exec(compile(f.read(), __file__, 'exec'))
  File "/home/jaxelsen/aisecurity/Ferminet_google/bin/ferminet", line 39, in <module>
    app.run(main)
  File "/opt/conda/lib/python3.7/site-packages/absl/app.py", line 312, in run
    _run_main(main, args)
  File "/opt/conda/lib/python3.7/site-packages/absl/app.py", line 258, in _run_main
    sys.exit(main(argv))
  File "/home/jaxelsen/aisecurity/Ferminet_google/bin/ferminet", line 35, in main
    train.train(cfg)
  File "/home/jaxelsen/aisecurity/Ferminet_google/ferminet/train.py", line 450, in train
    opt_state = optimizer.init(params, subkeys, data)
  File "/opt/conda/lib/python3.7/site-packages/kfac_ferminet_alpha/optimizer.py", line 498, in init
    self.finalize(params, rng, batch, func_state)
  File "/opt/conda/lib/python3.7/site-packages/kfac_ferminet_alpha/optimizer.py", line 244, in finalize
    patterns_to_skip=self.patterns_to_skip)
  File "/opt/conda/lib/python3.7/site-packages/kfac_ferminet_alpha/tag_graph_matcher.py", line 468, in auto_register_tags
    graph = function_to_jax_graph(func, func_args, params_index=params_index)
  File "/opt/conda/lib/python3.7/site-packages/kfac_ferminet_alpha/tag_graph_matcher.py", line 437, in function_to_jax_graph
    typed_jaxpr = jax.make_jaxpr(func)(*args)
  File "/opt/conda/lib/python3.7/site-packages/kfac_ferminet_alpha/tag_graph_matcher.py", line 256, in merged_func
    evaluate_eqn(eqn, jax_util.safe_map(read, eqn.invars), write)
  File "/opt/conda/lib/python3.7/site-packages/kfac_ferminet_alpha/tag_graph_matcher.py", line 172, in evaluate_eqn
    call_jaxpr, params = jax.core.extract_call_jaxpr(eqn.primitive, eqn.params)
AttributeError: module 'jax.core' has no attribute 'extract_call_jaxpr'
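The last frame shows kfac_ferminet_alpha's tag_graph_matcher.py calling `jax.core.extract_call_jaxpr`, which the installed JAX does not provide. A minimal sketch for checking this in a given environment, assuming only that JAX may or may not be importable there; the helper name `check_jax_core_attr` is hypothetical and not part of FermiNet, kfac_ferminet_alpha, or JAX:

```python
"""Check whether the installed JAX still exposes jax.core.extract_call_jaxpr.

Hypothetical diagnostic helper (not part of FermiNet or JAX). Written to run on
Python 3.7, the version shown in the traceback.
"""
import importlib
from typing import Optional, Tuple


def check_jax_core_attr(attr: str = "extract_call_jaxpr") -> Tuple[Optional[str], Optional[bool]]:
    """Return (jax_version, has_attr); both are None if JAX cannot be imported."""
    try:
        jax = importlib.import_module("jax")
        core = importlib.import_module("jax.core")
    except (ImportError, RuntimeError):
        # JAX missing, or jax/jaxlib mismatch raising at import time.
        return None, None
    return getattr(jax, "__version__", None), hasattr(core, attr)


if __name__ == "__main__":
    version, has_attr = check_jax_core_attr()
    if has_attr is None:
        print("JAX could not be imported in this environment.")
    elif has_attr:
        print(f"jax {version}: jax.core.extract_call_jaxpr is available.")
    else:
        print(f"jax {version}: jax.core.extract_call_jaxpr is missing, "
              "so kfac_ferminet_alpha will fail with the AttributeError above.")
```

If the attribute is missing, the installed JAX is newer than the version kfac_ferminet_alpha was written against; the reply below points to the related issue.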

jsspencer commented 2 years ago

See #43