I am running with jaxlib==0.3.0 on CUDA 11. The run starts on my two V100 GPUs but then stops with this error:
Traceback (most recent call last):
File "/opt/conda/bin/ferminet", line 7, in <module>
exec(compile(f.read(), __file__, 'exec'))
File "/home/jaxelsen/aisecurity/Ferminet_google/bin/ferminet", line 39, in <module>
app.run(main)
File "/opt/conda/lib/python3.7/site-packages/absl/app.py", line 312, in run
_run_main(main, args)
File "/opt/conda/lib/python3.7/site-packages/absl/app.py", line 258, in _run_main
sys.exit(main(argv))
File "/home/jaxelsen/aisecurity/Ferminet_google/bin/ferminet", line 35, in main
train.train(cfg)
File "/home/jaxelsen/aisecurity/Ferminet_google/ferminet/train.py", line 450, in train
opt_state = optimizer.init(params, subkeys, data)
File "/opt/conda/lib/python3.7/site-packages/kfac_ferminet_alpha/optimizer.py", line 498, in init
self.finalize(params, rng, batch, func_state)
File "/opt/conda/lib/python3.7/site-packages/kfac_ferminet_alpha/optimizer.py", line 244, in finalize
patterns_to_skip=self.patterns_to_skip)
File "/opt/conda/lib/python3.7/site-packages/kfac_ferminet_alpha/tag_graph_matcher.py", line 468, in auto_register_tags
graph = function_to_jax_graph(func, func_args, params_index=params_index)
File "/opt/conda/lib/python3.7/site-packages/kfac_ferminet_alpha/tag_graph_matcher.py", line 437, in function_to_jax_graph
typed_jaxpr = jax.make_jaxpr(func)(*args)
File "/opt/conda/lib/python3.7/site-packages/kfac_ferminet_alpha/tag_graph_matcher.py", line 256, in merged_func
evaluate_eqn(eqn, jax_util.safe_map(read, eqn.invars), write)
File "/opt/conda/lib/python3.7/site-packages/kfac_ferminet_alpha/tag_graph_matcher.py", line 172, in evaluate_eqn
call_jaxpr, params = jax.core.extract_call_jaxpr(eqn.primitive, eqn.params)
AttributeError: module 'jax.core' has no attribute 'extract_call_jaxpr'
I am running with jaxlib==0.3.0 on CUDA 11. The run starts on my two V100 GPUs but then stops with this error:
Traceback (most recent call last):
File "/opt/conda/bin/ferminet", line 7, in <module>
exec(compile(f.read(), __file__, 'exec'))
File "/home/jaxelsen/aisecurity/Ferminet_google/bin/ferminet", line 39, in <module>
app.run(main)
File "/opt/conda/lib/python3.7/site-packages/absl/app.py", line 312, in run
_run_main(main, args)
File "/opt/conda/lib/python3.7/site-packages/absl/app.py", line 258, in _run_main
sys.exit(main(argv))
File "/home/jaxelsen/aisecurity/Ferminet_google/bin/ferminet", line 35, in main
train.train(cfg)
File "/home/jaxelsen/aisecurity/Ferminet_google/ferminet/train.py", line 450, in train
opt_state = optimizer.init(params, subkeys, data)
File "/opt/conda/lib/python3.7/site-packages/kfac_ferminet_alpha/optimizer.py", line 498, in init
self.finalize(params, rng, batch, func_state)
File "/opt/conda/lib/python3.7/site-packages/kfac_ferminet_alpha/optimizer.py", line 244, in finalize
patterns_to_skip=self.patterns_to_skip)
File "/opt/conda/lib/python3.7/site-packages/kfac_ferminet_alpha/tag_graph_matcher.py", line 468, in auto_register_tags
graph = function_to_jax_graph(func, func_args, params_index=params_index)
File "/opt/conda/lib/python3.7/site-packages/kfac_ferminet_alpha/tag_graph_matcher.py", line 437, in function_to_jax_graph
typed_jaxpr = jax.make_jaxpr(func)(*args)
File "/opt/conda/lib/python3.7/site-packages/kfac_ferminet_alpha/tag_graph_matcher.py", line 256, in merged_func
evaluate_eqn(eqn, jax_util.safe_map(read, eqn.invars), write)
File "/opt/conda/lib/python3.7/site-packages/kfac_ferminet_alpha/tag_graph_matcher.py", line 172, in evaluate_eqn
call_jaxpr, params = jax.core.extract_call_jaxpr(eqn.primitive, eqn.params)
AttributeError: module 'jax.core' has no attribute 'extract_call_jaxpr'