sbi-dev / sbi

Simulation-based inference toolkit
https://sbi-dev.github.io/sbi/
Apache License 2.0

Distribution.log_prob() got an unexpected keyword argument 'condition' #1146

Closed: MolinAlexei closed this issue 5 months ago

MolinAlexei commented 5 months ago

I am running the following lines of code:


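import pickle

import numpy as np
import torch
from sbi import utils
from sbi.inference import SNLE

# Load a density estimator that was trained and pickled earlier.
with open("file_with_estimator.pkl", "rb") as handle:
    density_estimator_SNLE = pickle.load(handle)

prior = utils.BoxUniform(low=torch.asarray([0.0, 2.0, 150.0]),
                         high=torch.asarray([8.8, 3.0, 350.0]))
inferenceSNLE = SNLE(prior=prior)

posterior = inferenceSNLE.build_posterior(density_estimator_SNLE)

x = np.load('mydata.npy')
posterior.set_default_x(x)

# Evaluate the unnormalized log-posterior at a single parameter vector of zeros.
posterior.potential(torch.from_numpy(np.zeros((1, 3))))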

And I get the following error :

---------------------------------------------------------------------------
TypeError                                 Traceback (most recent call last)
Cell In[104], line 1
----> 1 posterior.potential(torch.from_numpy(np.zeros((1,3))))

File ~/miniconda3/envs/myenv/lib/python3.10/site-packages/sbi/inference/posteriors/base_posterior.py:101, in NeuralPosterior.potential(self, theta, x, track_gradients)
     98 self.potential_fn.set_x(self._x_else_default_x(x))
    100 theta = ensure_theta_batched(torch.as_tensor(theta))
--> 101 return self.potential_fn(
    102     theta.to(self._device), track_gradients=track_gradients
    103 )

File ~/miniconda3/envs/myenv/lib/python3.10/site-packages/sbi/inference/potentials/likelihood_based_potential.py:95, in LikelihoodBasedPotential.__call__(self, theta, track_gradients)
     84 r"""Returns the potential $\log(p(x_o|\theta)p(\theta))$.
     85 
     86 Args:
   (...)
     91     The potential $\log(p(x_o|\theta)p(\theta))$.
     92 """
     94 # Calculate likelihood over trials and in one batch.
---> 95 log_likelihood_trial_sum = _log_likelihoods_over_trials(
     96     x=self.x_o,
     97     theta=theta.to(self.device),
     98     estimator=self.likelihood_estimator,
     99     track_gradients=track_gradients,
    100 )
    102 return log_likelihood_trial_sum + self.prior.log_prob(theta)

File ~/miniconda3/envs/myenv/lib/python3.10/site-packages/sbi/inference/potentials/likelihood_based_potential.py:151, in _log_likelihoods_over_trials(x, theta, estimator, track_gradients)
    149 # Calculate likelihood in one batch.
    150 with torch.set_grad_enabled(track_gradients):
--> 151     log_likelihood_trial_batch = estimator.log_prob(x, condition=theta)
    152     # Sum over trial-log likelihoods.
    153     log_likelihood_trial_sum = log_likelihood_trial_batch.sum(0)
TypeError: Distribution.log_prob() got an unexpected keyword argument 'condition'

I have sbi version 0.22.0

michaeldeistler commented 5 months ago

Did you install sbi from PyPI with pip install sbi, or the most recent version from GitHub?

MolinAlexei commented 5 months ago

I installed it with pip install git+https://github.com/sbi-dev/sbi.git about a week and a half ago (sorry, I don't have the exact date).
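Whether a package came from PyPI or from a git URL can also be checked from the installation metadata instead of from memory. The sketch below is a hypothetical helper (`install_source` is not part of sbi); it assumes only the standard-library `importlib.metadata` and the PEP 610 `direct_url.json` file that recent versions of pip write for installs from a URL or VCS, such as `pip install git+https://...`. A missing file most likely means an index install such as PyPI, but other installers (for example conda) may also omit it, so treat that case as a hint only.

```python
import json
from importlib import metadata
from typing import Optional


def install_source(dist_name: str) -> Optional[dict]:
    """Describe how a distribution was installed, using PEP 610 metadata.

    Returns None if the distribution is not installed. Otherwise returns a dict
    with the installed version and, for direct-URL installs (e.g. from git),
    the source URL and commit id.
    """
    try:
        dist = metadata.distribution(dist_name)
    except metadata.PackageNotFoundError:
        return None

    info = {
        "version": dist.version,
        "source": "no direct_url.json (likely a package index such as PyPI)",
        "url": None,
        "commit": None,
    }
    # pip writes direct_url.json only for installs from a URL, VCS, or local path.
    raw = dist.read_text("direct_url.json")
    if raw:
        direct = json.loads(raw)
        info["source"] = "direct URL"
        info["url"] = direct.get("url")
        info["commit"] = direct.get("vcs_info", {}).get("commit_id")
    return info


if __name__ == "__main__":
    print(install_source("sbi"))
```

For a GitHub install this would report the repository URL and commit, which identifies the exact code more reliably than a version string, since a development install may carry a version number that does not pin down which changes it contains.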

michaeldeistler commented 5 months ago

Okay, thanks, we will have a look!

michaeldeistler commented 5 months ago

Actually, one more question: You are loading the density estimator

with open("file_with_estimator.pkl", "rb") as handle:
    density_estimator_SNLE = pickle.load(handle)

Did you create this density estimator under an older version of sbi?

MolinAlexei commented 5 months ago

Yes, I did.
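Because the failure comes from unpickling an estimator created under an older sbi, one defensive habit is to store the library version next to the pickled object and check it on load. This is a minimal stdlib-only sketch with hypothetical helper names (`save_with_version`, `load_with_version`); it is not an sbi feature, and the version strings are illustrative. As always with pickle, only load files you trust.

```python
import pickle
from typing import Any, BinaryIO


def save_with_version(obj: Any, handle: BinaryIO, version: str) -> None:
    """Pickle `obj` together with the library version it was created under."""
    pickle.dump({"library_version": version, "payload": obj}, handle)


def load_with_version(handle: BinaryIO, current_version: str) -> Any:
    """Unpickle an object, refusing it if it was saved under another version."""
    record = pickle.load(handle)
    # Plain pickles written without this wrapper carry no version information.
    saved_version = record.get("library_version") if isinstance(record, dict) else None
    if saved_version != current_version:
        raise RuntimeError(
            f"Object was pickled under version {saved_version!r}, but "
            f"{current_version!r} is installed; re-create or retrain it."
        )
    return record["payload"]


if __name__ == "__main__":
    import io

    buffer = io.BytesIO()
    save_with_version({"weights": [0.1, 0.2]}, buffer, version="1.0")

    buffer.seek(0)
    print(load_with_version(buffer, current_version="1.0"))

    buffer.seek(0)
    try:
        load_with_version(buffer, current_version="2.0")
    except RuntimeError as err:
        print(err)
```

In practice `version` could come from `importlib.metadata.version("sbi")`, possibly combined with a commit id for development installs, since the version string alone may not distinguish them.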

MolinAlexei commented 5 months ago

I'm guessing I need to retrain the density estimator with the new version in order to use this feature?

michaeldeistler commented 5 months ago

Okay, then that's the reason. We have changed the density estimators in sbi, and you cannot use "old" density estimators with the newest sbi version (from GitHub).
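The traceback shows where the incompatibility surfaces: the newer potential function calls `estimator.log_prob(x, condition=theta)`, while the unpickled object's `Distribution.log_prob()` rejects the `condition` keyword. Judging from the class name in the error, we assume the old estimator follows an nflows-style signature `log_prob(inputs, context=None)`. The two classes below are hypothetical stand-ins rather than sbi classes, and `accepts_condition_kwarg` is a hypothetical helper that uses `inspect.signature` to check a loaded estimator before it is passed to `build_posterior`.

```python
import inspect


class OldStyleEstimator:
    """Hypothetical stand-in for an estimator pickled under an older sbi.

    Assumed to follow the nflows-style signature log_prob(inputs, context=None).
    """

    def log_prob(self, inputs, context=None):
        return 0.0


class NewStyleEstimator:
    """Hypothetical stand-in for an estimator built with the newer sbi."""

    def log_prob(self, inputs, condition):
        return 0.0


def accepts_condition_kwarg(estimator) -> bool:
    """Return True if estimator.log_prob can be called with condition=..."""
    params = inspect.signature(estimator.log_prob).parameters
    return "condition" in params or any(
        p.kind is inspect.Parameter.VAR_KEYWORD for p in params.values()
    )


for est in (OldStyleEstimator(), NewStyleEstimator()):
    try:
        # Mimic the call made by the newer potential function (see traceback).
        est.log_prob([0.0], condition=[0.0])
        outcome = "ok"
    except TypeError as err:
        outcome = f"TypeError: {err}"
    print(type(est).__name__, accepts_condition_kwarg(est), outcome)
```

If the check fails for a pickled estimator, the options discussed in this thread apply: stay on the release that created it, or retrain it under the new version.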

I recommend installing sbi from PyPI with pip install sbi.

MolinAlexei commented 5 months ago

Okay, thank you! I installed it this way because, in another test I'm running, I need to access the output of an embedding net (through learned_summary_stats = trained_estimator.embedding_net(x)).

michaeldeistler commented 5 months ago

Ah, okay. Yes, unfortunately you will have to retrain the density estimator here.