google-deepmind / ferminet

An implementation of the Fermionic Neural Network for ab-initio electronic structure calculations
Apache License 2.0
740 stars 129 forks

Program gets stuck somewhere #78

Closed whopezx closed 2 months ago

whopezx commented 2 months ago

Hello,

When I run the code on an NVIDIA A30, the program gets stuck in pretrain.py after this function (defined in make_pretrain_step) returns, but on CPU it does not.

def pretrain_step(data, params, state, key, scf_approx):
    """One iteration of pretraining to match HF."""
    val_and_grad = jax.value_and_grad(loss_fn, argnums=0)
    loss_val, search_direction = val_and_grad(params, data, scf_approx)
    search_direction = constants.pmean(search_direction)
    updates, state = optimizer_update(search_direction, state, params)
    params = optax.apply_updates(params, updates)
    full_params = {'ferminet': params, 'scf': scf_approx}
    data, pmove = mcmc_step(full_params, data, key, width=0.02)
    return data, params, state, loss_val, pmove

I used pdb to check where the code gets stuck, and found that after this function returns, the program calls some JAX internal code. It hangs at this line:

/home/ps/users/whb/natural_excited_states/venvcuda/ferminet/lib/python3.11/site-packages/jax/_src/api.py(1774)cache_miss()
-> out = map_bind_continuation(execute(*tracers))

top shows the process in state Sl+ (sleeping, multi-threaded, foreground). I wonder whether my JAX installation is incorrect? Here are my jax and jaxlib versions:

jax     0.4.12
jaxlib  0.4.12+cuda12.cudnn89

Because the CUDA version is 12.1, I use jax 0.4.12. I installed jax with pip install jax==0.4.12 and jaxlib with pip install ./jaxlib-0.4.12+cuda12.cudnn89-cp311-cp311-manylinux2014_x86_64.whl.
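For what it's worth, the base jax and jaxlib releases do need to agree; here is a minimal sketch of that check, ignoring the local build tag on the jaxlib wheel (the helper name is my own, not part of jax):

```python
def versions_match(jax_version: str, jaxlib_version: str) -> bool:
    """Return True if the base releases agree, ignoring local build tags
    such as '+cuda12.cudnn89' on a locally installed jaxlib wheel."""
    return jax_version.split("+")[0] == jaxlib_version.split("+")[0]

# The versions reported above do match:
print(versions_match("0.4.12", "0.4.12+cuda12.cudnn89"))  # prints True
```

Since the versions match here, a plain version mismatch seems unlikely to be the cause.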

whopezx commented 2 months ago

I tried on an NVIDIA A100 and the program runs correctly. There may be some setting on the A30 that stops the program.

jsspencer commented 2 months ago

You might be hitting an OOM issue or an XLA compilation issue; it's hard to know without careful debugging on your own setup.
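A cheap first check for the OOM hypothesis, using JAX's standard GPU memory environment variables (a sketch, not a confirmed fix; the A30 has less memory than the A100 that worked):

```shell
# Disable JAX's default preallocation of most of GPU memory, so a true OOM
# surfaces as an explicit allocator error rather than odd behaviour.
export XLA_PYTHON_CLIENT_PREALLOCATE=false
# Alternatively, cap the fraction of device memory JAX may claim.
export XLA_PYTHON_CLIENT_MEM_FRACTION=0.6
```

If the run then fails with an explicit out-of-memory error instead of hanging, reducing the batch size would be a reasonable next step.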