and this function is implemented with np.savez.
However, when I try to load this checkpoint, it fails the check in checkpoint.restore, specifically:
```python
import jax
import jax.numpy as jnp
import numpy as np

with open(restore_filename, 'rb') as f:
  ckpt_data = np.load(f, allow_pickle=True)
  # Retrieve data from npz file. Non-array variables need to be converted back
  # to native types using .tolist().
  t = ckpt_data['t'].tolist() + 1  # Return the iterations completed.
  data = ckpt_data['data']
  params = ckpt_data['params'].tolist()
  opt_state = ckpt_data['opt_state'].tolist()
  mcmc_width = jnp.array(ckpt_data['mcmc_width'].tolist())
  if data.shape[0] != jax.device_count():
    raise ValueError(
        f'Incorrect number of devices found. Expected {data.shape[0]}, found '
        f'{jax.device_count()}.')
```
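A minimal reproduction of why this check can fail may help. It assumes that FermiNetData behaves like a Mapping (for example, a chex dataclass with its default mappable behavior; I have not verified this against the FermiNet source), and `MappableData` below is a hypothetical stand-in, not the real class. np.savez converts each value with np.asanyarray, and NumPy treats a non-dict object that defines `__getitem__`, `__iter__` and `__len__` as a sequence. Iterating such an object yields the field names, so the saved `data` becomes a 1-D array of four strings and `data.shape[0]` is 4 instead of the device count.

```python
import collections.abc
import dataclasses
import io

import numpy as np


@dataclasses.dataclass
class MappableData(collections.abc.Mapping):
  """Hypothetical dict-like stand-in for FermiNetData (not the real class)."""
  positions: np.ndarray
  spins: np.ndarray
  atoms: np.ndarray
  charges: np.ndarray

  # Mapping interface: iteration yields the field names, like dict keys.
  def __getitem__(self, key):
    return getattr(self, key)

  def __iter__(self):
    return iter([f.name for f in dataclasses.fields(self)])

  def __len__(self):
    return len(dataclasses.fields(self))


n_devices = 1  # Hypothetical, e.g. jax.device_count() on a single-GPU machine.
# Illustrative shapes only; the leading axis is the device axis.
data = MappableData(
    positions=np.zeros((n_devices, 4, 6)),
    spins=np.ones((n_devices, 4, 2)),
    atoms=np.zeros((n_devices, 4, 1, 3)),
    charges=np.ones((n_devices, 4, 1)),
)

buf = io.BytesIO()
np.savez(buf, data=data)  # np.savez calls np.asanyarray on each value.
buf.seek(0)
with np.load(buf, allow_pickle=True) as ckpt_data:
  loaded = ckpt_data['data']

print(loaded)        # The four field names, not the arrays.
print(loaded.shape)  # (4,), so the device check compares 4 with n_devices.
```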
I tried to track down the cause and found that in checkpoint.save, `data` is a FermiNetData object containing four arrays named position, spins, atoms, and charges. However, when I load this NumPy checkpoint, `data` is merely an array of strings, [position, spins, atoms, charges], so `data.shape[0]` is 4 rather than the number of devices. It seems something may be wrong in this part.
I'm wondering whether this part is correct. Thank you very much for your time and effort in reading my issue.
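One possible workaround, sketched under the same assumption (a dict-like FermiNetData) and with the same hypothetical `MappableData` stand-in, is to save the fields as a plain dict and rebuild the object on restore. NumPy wraps a real dict in a 0-d object array instead of iterating it, and `.tolist()` on that array returns the dict, much like how `params` and `opt_state` are read back above. `save_checkpoint` and `restore_checkpoint` are hypothetical helpers; this is only a sketch, not FermiNet's actual fix.

```python
import collections.abc
import dataclasses
import io

import numpy as np


@dataclasses.dataclass
class MappableData(collections.abc.Mapping):
  """Hypothetical dict-like stand-in for FermiNetData (not the real class)."""
  positions: np.ndarray
  spins: np.ndarray
  atoms: np.ndarray
  charges: np.ndarray

  def __getitem__(self, key):
    return getattr(self, key)

  def __iter__(self):
    return iter([f.name for f in dataclasses.fields(self)])

  def __len__(self):
    return len(dataclasses.fields(self))


def save_checkpoint(f, t, data):
  # Store the fields as a plain dict: np.savez wraps a dict in a 0-d object
  # array (pickled) instead of treating it as a sequence of keys.
  np.savez(f, t=t, data=dataclasses.asdict(data))


def restore_checkpoint(f):
  with np.load(f, allow_pickle=True) as ckpt_data:
    t = ckpt_data['t'].tolist() + 1  # Return the iterations completed.
    # .tolist() on the 0-d object array returns the original dict.
    data = MappableData(**ckpt_data['data'].tolist())
  return t, data


n_devices = 1  # Hypothetical device count.
data = MappableData(
    positions=np.zeros((n_devices, 4, 6)),
    spins=np.ones((n_devices, 4, 2)),
    atoms=np.zeros((n_devices, 4, 1, 3)),
    charges=np.ones((n_devices, 4, 1)),
)
buf = io.BytesIO()
save_checkpoint(buf, t=9, data=data)
buf.seek(0)
t, restored = restore_checkpoint(buf)
print(t, restored.positions.shape[0])  # 10 1
```

Storing a dict pickles the arrays, so `allow_pickle=True` is still needed on load, as in the existing restore code.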