sbi-dev / sbi

Simulation-based inference toolkit
https://sbi-dev.github.io/sbi/
Apache License 2.0

Crash with SNPE "mdn" after multi-rounds: NaN/Inf present in posterior eval. #786

Closed jecampagne closed 2 months ago

jecampagne commented 1 year ago

Hi, I am hitting this NaN/Inf crash (sorry, it only happened after a long multi-round run):

Round[5]: density_estimator training
Using SNPE-C with atomic loss
 Training neural network. Epochs trained: 17
Traceback (most recent call last):
  File "DESY1_sbi_SNPE_multi.py", line 375, in <module>
    do_multi_pass(num_simu = 100_000, num_rounds = 10, tag="mdn_default",
  File "DESY1_sbi_SNPE_multi.py", line 321, in do_multi_pass
    density_estimator = inference.append_simulations(
  File "/anaconda3/envs/sbi/lib/python3.8/site-packages/sbi/inference/snpe/snpe_c.py", line 179, in train
    return super().train(**kwargs)
  File "/anaconda3/envs/sbi/lib/python3.8/site-packages/sbi/inference/snpe/snpe_base.py", line 327, in train
    train_losses = self._loss(
  File "/anaconda3/envs/sbi/lib/python3.8/site-packages/sbi/inference/snpe/snpe_base.py", line 548, in _loss
    log_prob = self._log_prob_proposal_posterior(theta, x, masks, proposal)
  File "/anaconda3/envs/sbi/lib/python3.8/site-packages/sbi/inference/snpe/snpe_c.py", line 280, in _log_prob_proposal_posterior
    return self._log_prob_proposal_posterior_atomic(theta, x, masks)
  File "/anaconda3/envs/sbi/lib/python3.8/site-packages/sbi/inference/snpe/snpe_c.py", line 331, in _log_prob_proposal_posterior_atomic
    utils.assert_all_finite(log_prob_posterior, "posterior eval")
  File "/anaconda3/envs/sbi/lib/python3.8/site-packages/sbi/utils/torchutils.py", line 372, in assert_all_finite
    assert torch.isfinite(quantity).all(), msg
AssertionError: NaN/Inf present in posterior eval.
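
The failing check at the bottom of the traceback can be illustrated in isolation. The sketch below is a stand-alone, pure-Python analogue for readers following along, not the sbi implementation (which, per the traceback, uses `torch.isfinite` on a tensor). The helper `find_non_finite` and the sample values are hypothetical and only show how one might locate the offending entries before the assertion fires.

```python
import math


def find_non_finite(values):
    """Return (index, value) pairs for every NaN or +/-Inf entry."""
    return [(i, v) for i, v in enumerate(values) if not math.isfinite(v)]


def assert_all_finite(values, description):
    """Pure-Python analogue of the check in the traceback.

    Raises AssertionError if any entry is NaN or infinite, and reports
    which entries were affected (the reporting is an addition for this sketch).
    """
    bad = find_non_finite(values)
    assert not bad, f"NaN/Inf present in {description}: {bad}"


# Hypothetical log-probabilities; the last two mimic a diverged estimator.
log_probs = [-3.2, -1.7, float("nan"), float("-inf")]
print(find_non_finite(log_probs))
```

Calling `assert_all_finite(log_probs, "posterior eval")` on these values raises an `AssertionError` whose message starts the same way as the one above.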

It's a pity, as the multi-round contours were slowly converging to those from NUTS sampling of the likelihood, which serve as the "true" contours.

Any idea how to fix the problem?

michaeldeistler commented 1 year ago

Hi, thanks for reporting. This seems related to this. What is your version of pyknos?

jecampagne commented 1 year ago

pyknos 0.15.1

Looking at the issue you pointed to, I should mention that I have dim=21 and use 100,000 simulations per round. The crash occurs only during the 6th round, but the 5th-round contours look nice.

janfb commented 2 months ago

Closing this due to inactivity. Please feel free to re-open if this is still relevant.