Closed: Baiyu-Su closed this issue 1 year ago
Hi,
It's difficult to say without knowing what the shapes of `params`, `x`, `data.spins` and `data.atoms` are. I would advise starting with the unbatched network, and just putting in subsets of the params/positions/spins/atoms that have the correct leading dimensions. Good luck.
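For illustration, the kind of leading-dimension check suggested above can be sketched with a toy stand-in. The `network` function, its parameters, and all shapes below are made up for this sketch; they are not FermiNet's actual network or signature:

```python
import jax
import jax.numpy as jnp

# Toy stand-in for a per-sample network: maps (params, single position) -> scalar.
# The real network takes more arguments (spins, atoms, charges); this is a sketch.
def network(params, x):
    return jnp.sum(params["w"] * x)

params = {"w": jnp.ones(3)}                 # shared parameters (no batch axis)
batch_x = jnp.arange(12.0).reshape(4, 3)    # positions with a leading batch axis of 4

# Unbatched call: slice off the leading batch axis to get one sample.
single = network(params, batch_x[0])

# Batched call: vmap over the batch axis of x only; params are shared (in_axes=None).
batch_network = jax.vmap(network, in_axes=(None, 0))
batched = batch_network(params, batch_x)

assert batched.shape == (4,)                # one scalar per sample in the batch
assert jnp.allclose(batched[0], single)     # batched result matches the unbatched call
```

Slicing off leading axes one at a time like this usually makes it clear which argument's shape disagrees with the `in_axes` specification.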
Also, please only open an issue on GitHub if there is a bug or a feature request for the package itself. It sounds like nothing is actually wrong with the package, so I'm closing the issue.
David
Hi all,
To evaluate the model's performance, I need to compute the log probability of a new set of data positions inside the main training loop. I call

`batch_network = constants.pmap(batch_network)`

with `batch_network` being the vmapped function defined in the `train` function, and then

`log_prob = 2.0 * batch_network(params, x, data.spins, data.atoms, data.charges)`

where `x` is a JAX array with the same dimensions as `data.positions`. However, I encountered the following issue:
I tried to avoid pmapping the `params` input by setting `in_axes=(None, 0, 0, 0, 0)` and got:

Replacing `x` by the default `data.positions` gives rise to the exact same issue. Without these two lines of code, everything else works fine. How can I evaluate the log probability correctly in the training loop?
Thanks!