google-deepmind / ferminet

An implementation of the Fermionic Neural Network for ab-initio electronic structure calculations
Apache License 2.0

Question About load Checkpoint #60

Closed: ZiciuCanJustus closed this issue 1 year ago

ZiciuCanJustus commented 1 year ago

Hello, I have a question about the checkpoint loading function. As far as I can tell, the model is saved in /ferminet/train.py by:

      if time.time() - time_of_last_ckpt > cfg.log.save_frequency * 60:
        checkpoint.save(ckpt_save_path, t, data, params, opt_state, mcmc_width)
        time_of_last_ckpt = time.time()
        sys.exit(0)

This function is implemented with np.savez. However, when I try to load this checkpoint, it fails the check in checkpoint.restore, specifically:

  with open(restore_filename, 'rb') as f:
    ckpt_data = np.load(f, allow_pickle=True)
    # Retrieve data from npz file. Non-array variables need to be converted back
    # to natives types using .tolist().
    t = ckpt_data['t'].tolist() + 1  # Return the iterations completed.
    data = ckpt_data['data']
    params = ckpt_data['params'].tolist()
    opt_state = ckpt_data['opt_state'].tolist()
    mcmc_width = jnp.array(ckpt_data['mcmc_width'].tolist())
    if data.shape[0] != jax.device_count():
      raise ValueError(
          f'Incorrect number of devices found. Expected {data.shape[0]}, found '
          f'{jax.device_count()}.')

I tried to track this down and found that, in checkpoint.save, data is a FermiNetData instance, which contains four arrays named position, spins, atoms, and charges. However, when I load the checkpoint with NumPy, data is merely an array of the strings [position, spins, atoms, charges] rather than the arrays themselves. Is this part of the code behaving as intended, or is it a bug? Thank you very much for your time and for reading my issue.

jsspencer commented 1 year ago

Sorry, this was broken by mistake with some refactoring. This is now fixed.
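
For readers who hit the same symptom on an older checkout, here is a minimal sketch of one way to round-trip a dataclass of arrays through np.savez without losing the array contents. It is not taken from the repository: WalkerData is a hypothetical stand-in for FermiNetData (its field names are assumed), and save_checkpoint / restore_checkpoint are illustrative helpers, not the library's API. The idea is to store the dataclass as a plain dict (which NumPy pickles as a 0-d object array) and rebuild the dataclass on load. A plausible reason for the strings reported above is that NumPy iterated over the dataclass as if it were a sequence of field names, but that is an assumption.

```python
import dataclasses
import io

import numpy as np


@dataclasses.dataclass
class WalkerData:
    """Hypothetical stand-in for FermiNetData (field names are assumed)."""
    positions: np.ndarray
    spins: np.ndarray
    atoms: np.ndarray
    charges: np.ndarray


def save_checkpoint(f, t, data):
    """Illustrative helper: writes the step counter and the data fields to f."""
    # Convert the dataclass to a plain dict first. np.savez stores the dict as
    # a pickled 0-d object array, so the arrays inside it are preserved.
    np.savez(f, t=t, data=dataclasses.asdict(data))


def restore_checkpoint(f):
    """Illustrative helper: reads back what save_checkpoint wrote."""
    ckpt = np.load(f, allow_pickle=True)
    t = ckpt['t'].tolist() + 1  # Iterations completed, as in the quoted code.
    # .item() unwraps the 0-d object array back into the original dict.
    data = WalkerData(**ckpt['data'].item())
    return t, data


# Usage: round-trip through an in-memory buffer instead of a file on disk.
example = WalkerData(
    positions=np.zeros((2, 8, 6)),  # Leading axis plays the role of devices.
    spins=np.ones((2, 8, 2)),
    atoms=np.zeros((2, 8, 3)),
    charges=np.full((2, 8, 1), 2.0),
)
buffer = io.BytesIO()
save_checkpoint(buffer, 10, example)
buffer.seek(0)
step, restored = restore_checkpoint(buffer)
```

With the data restored as a dataclass, a device-count check would need to look at the leading axis of one of its fields (for example restored.positions.shape[0]) rather than at data.shape[0]. Note that allow_pickle=True should only be used on checkpoint files you trust.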