google-deepmind / ferminet

An implementation of the Fermionic Neural Network for ab-initio electronic structure calculations
Apache License 2.0

Evaluating logprob using batch_network in train #61

Closed: Baiyu-Su closed this issue 1 year ago

Baiyu-Su commented 1 year ago

Hi all,

To evaluate the model's performance, I need to compute the log probability of a new set of positions inside the main training loop. I take the vmapped batch_network defined in the train function, wrap it with batch_network = constants.pmap(batch_network), and then compute log_prob = 2.0 * batch_network(params, x, data.spins, data.atoms, data.charges), where x is a JAX array with the same shape as data.positions.
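The vmap-then-pmap pattern described above can be illustrated with a toy function. This is a minimal sketch under stated assumptions, not FermiNet's code: `toy_log_abs_psi`, its parameter names `w`/`b`, and all sizes are hypothetical. It shows the shape contract: with pmap's default in_axes=0, every leaf of params and the positions array must carry the same leading device axis, of size jax.local_device_count().

```python
import jax
import jax.numpy as jnp

def toy_log_abs_psi(params, pos):
    # Hypothetical stand-in for log|psi| of ONE walker; pos has shape (3 * n_elec,).
    return jnp.sum(jnp.tanh(pos @ params["w"] + params["b"]))

n_dev = jax.local_device_count()
walkers_per_device, n_elec = 8, 4

# vmap over walkers: params shared (in_axes=None), positions batched on axis 0.
batched = jax.vmap(toy_log_abs_psi, in_axes=(None, 0))
# pmap with the default in_axes=0: EVERY argument needs a leading device axis.
sharded = jax.pmap(batched)

key_w, key_x = jax.random.split(jax.random.PRNGKey(0))
params = {"w": jax.random.normal(key_w, (3 * n_elec, 16)), "b": jnp.zeros(16)}
# Replicate params by adding a leading axis of size n_dev to each leaf.
rep_params = jax.tree_util.tree_map(
    lambda a: jnp.broadcast_to(a, (n_dev,) + a.shape), params)
positions = jax.random.normal(key_x, (n_dev, walkers_per_device, 3 * n_elec))

# log p = 2 log|psi|; result has shape (n_dev, walkers_per_device).
log_prob = 2.0 * sharded(rep_params, positions)
```

Printing the leading dimensions of params and positions at the call site is a quick way to see which convention your own arguments actually follow.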

However, I encountered the following error:

log_prob = 2.0 * batch_network(params, data.positions, data.spins, data.atoms, data.charges)

ValueError: pmap got inconsistent sizes for array axes to be mapped:
  * most axes (6 of them) had size 256, e.g. axis 0 of argument args[0]['layers']['streams'][0]['single']['b'] of type float32[256];
  * some axes (5 of them) had size 32, e.g. axis 0 of argument args[0]['layers']['streams'][0]['double']['b'] of type float32[32];
  * some axes (4 of them) had size 6, e.g. axis 0 of argument args[0]['envelope'][0]['pi'] of type float32[6,256];
  * some axes (4 of them) had size 1024, e.g. axis 0 of argument args[1] of type float32[1024,48];
  * some axes (3 of them) had size 832, e.g. axis 0 of argument args[0]['layers']['streams'][1]['single']['w'] of type float32[832,256];
  * one axis had size 4: axis 0 of argument args[0]['layers']['streams'][0]['double']['w'] of type float32[4,32];
  * one axis had size 80: axis 0 of argument args[0]['layers']['streams'][0]['single']['w'] of type float32[80,256]
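pmap maps over axis 0 of every array leaf of every argument whose in_axes entry is 0 (the default), so all of those leading sizes must agree. The sizes reported above (256, 32, 6, 832, 80, and 1024 for a float32[1024,48] argument) look like the ordinary first dimensions of the parameter arrays and of the position array, which suggests that neither params nor the positions carried a leading device axis at that call. The toy reproduction that follows uses a hypothetical function and hypothetical shapes, not FermiNet code, and triggers the same kind of ValueError.

```python
import jax
import jax.numpy as jnp

def toy_log_abs_psi(params, pos):
    # Hypothetical single-walker log|psi|; pos has shape (12,).
    return jnp.sum(jnp.tanh(pos @ params["w"] + params["b"]))

batched = jax.vmap(toy_log_abs_psi, in_axes=(None, 0))
sharded = jax.pmap(batched)  # default in_axes=0 for every argument

params = {"w": jnp.ones((12, 16)), "b": jnp.zeros(16)}  # no device axis
positions = jnp.ones((32, 12))                          # no device axis either

# Leading sizes are 12 (w), 16 (b) and 32 (positions): they disagree.
try:
    sharded(params, positions)
    raised = False
except ValueError:
    raised = True
```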

I tried to avoid pmapping the params input by setting in_axes=(None, 0, 0, 0, 0) in the pmap call, and got the following traceback:

  File "/home/baiyu/ferminet/train.py", line 828, in train
    log_prob = 2.0 * batch_network(params, data.positions, data.spins, data.atoms, data.charges)
  File "/home/baiyu/ferminet/train.py", line 558, in <lambda>
    logabs_network = lambda *args, **kwargs: signed_network(*args, **kwargs)[1]
  File "/home/baiyu/ferminet/networks.py", line 1387, in apply
    orbitals = orbitals_apply(params, pos, spins, atoms, charges)
  File "/home/baiyu/ferminet/networks.py", line 1171, in apply
    h_to_orbitals = equivariant_layers_apply(
  File "/home/baiyu/ferminet/networks.py", line 1029, in apply
    h_one, h_two, h_elec_ion = apply_layer(
  File "/home/baiyu/ferminet/networks.py", line 937, in apply_layer
    h_one_in = construct_symmetric_features(
  File "/home/baiyu/ferminet/networks.py", line 552, in construct_symmetric_features
    return jnp.concatenate(features, axis=1)
  File "/home/baiyu/miniconda/envs/newenv/lib/python3.10/site-packages/jax/_src/numpy/lax_numpy.py", line 1837, in concatenate
    arrays_out = [lax.concatenate(arrays_out[i:i+k], axis)
  File "/home/baiyu/miniconda/envs/newenv/lib/python3.10/site-packages/jax/_src/numpy/lax_numpy.py", line 1837, in <listcomp>
    arrays_out = [lax.concatenate(arrays_out[i:i+k], axis)
TypeError: Cannot concatenate arrays with shapes that differ in dimensions other than the one being concatenated: concatenating along dimension 2 for shapes (1024, 16, 1, 256), (1024, 1, 16, 256), (1024, 1, 16, 256), (1024, 16, 1, 32), (1024, 16, 1, 32).

Replacing x with data.positions itself gives exactly the same errors. Without these two lines of code, everything else works fine.
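With in_axes=(None, 0, ...) on pmap, the contract changes: params are sent unchanged to every device and must not have a device axis, while the positions still need an explicit leading device axis. A minimal sketch under stated assumptions: `toy_log_abs_psi` and `shard_batch` are hypothetical helpers, and the batch size is assumed to be divisible by the device count.

```python
import jax
import jax.numpy as jnp

def toy_log_abs_psi(params, pos):
    # Hypothetical single-walker log|psi|; pos has shape (3 * n_elec,).
    return jnp.sum(jnp.tanh(pos @ params["w"] + params["b"]))

def shard_batch(x, n_dev):
    # Hypothetical helper: split a flat walker batch (B, D) into
    # (n_dev, B // n_dev, D). Assumes B is divisible by n_dev.
    return x.reshape((n_dev, x.shape[0] // n_dev) + x.shape[1:])

n_dev = jax.local_device_count()
batched = jax.vmap(toy_log_abs_psi, in_axes=(None, 0))
# in_axes=None for params: the same un-replicated pytree goes to every device.
sharded = jax.pmap(batched, in_axes=(None, 0))

params = {"w": 0.1 * jnp.ones((12, 16)), "b": jnp.zeros(16)}  # no device axis
flat_positions = jnp.ones((8 * n_dev, 12))                    # (B, 3 * n_elec)
log_prob = 2.0 * sharded(params, shard_batch(flat_positions, n_dev))
```

Here every walker sees pos @ w = 1.2 in each of the 16 units, so each entry of log_prob equals 2 * 16 * tanh(1.2).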

How can I correctly evaluate the log probability inside the training loop?

Thanks!

dpfau commented 1 year ago

Hi,

It's difficult to say without knowing what the shapes of params, x, data.spins and data.atoms are. I would advise starting with the unbatched network, and just putting in subsets of the params/position/spins/atoms that have the correct leading dimensions. Good luck.
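The advice above (check the shapes, then start from the unbatched network with suitable slices of the inputs) can be sketched as follows. The function and shapes are hypothetical stand-ins, not FermiNet's network. jax.tree_util.tree_map lists the shape of every parameter leaf, indexing [0] strips the device axis, and jax.eval_shape reports the batched output shape without running the computation.

```python
import jax
import jax.numpy as jnp

def toy_log_abs_psi(params, pos):
    # Hypothetical single-walker log|psi|; pos has shape (3 * n_elec,).
    return jnp.sum(jnp.tanh(pos @ params["w"] + params["b"]))

def leaf_shapes(tree):
    # Replace every leaf of a pytree with its shape, for quick inspection.
    return jax.tree_util.tree_map(jnp.shape, tree)

n_dev = jax.local_device_count()
params = {"w": 0.1 * jnp.ones((12, 16)), "b": jnp.zeros(16)}
rep_params = jax.tree_util.tree_map(lambda a: jnp.stack([a] * n_dev), params)
positions = jnp.ones((n_dev, 8, 12))  # (n_dev, walkers_per_device, 3 * n_elec)

print(leaf_shapes(rep_params))  # every leaf should start with n_dev
print(positions.shape)

# Step 1: strip the device axis and evaluate ONE walker, unbatched.
params0 = jax.tree_util.tree_map(lambda a: a[0], rep_params)
single = toy_log_abs_psi(params0, positions[0, 0])

# Step 2: check the vmapped function's output shape abstractly before running it.
batched = jax.vmap(toy_log_abs_psi, in_axes=(None, 0))
out_struct = jax.eval_shape(batched, params0, positions[0])
```

Only once both steps give the expected shapes is it worth adding the pmap layer back.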

Also, please only open an issue on Github if there is a bug or a feature request for the package itself. It sounds like nothing is actually wrong with the package, so I'm closing the issue.

David