sbi-dev / sbi

Simulation-based inference toolkit
https://sbi-dev.github.io/sbi/
Apache License 2.0

Using MCMCPosterior gives 'Sequential' object has no attribute 'log_prob' #1173

Closed: almoyer closed this issue 2 weeks ago.

almoyer commented 3 weeks ago

**Describe the bug** When using `SNRE_A` with the `MCMCPosterior` sampler on a simple model, I cannot sample from the posterior because this error occurs: `AttributeError: 'Sequential' object has no attribute 'log_prob'`

**To Reproduce**

  1. sbi version: 0.22.0; Python version: 3.10.12. Code example:

```python
import torch
from sbi.inference import SNRE_A
import sbi.inference
from sbi.inference import likelihood_estimator_based_potential, MCMCPosterior

num_dim = 2
prior = torch.distributions.MultivariateNormal(torch.zeros(num_dim), torch.eye(num_dim))
theta = prior.sample((1000,))
x = theta + torch.randn((1000, num_dim))
x_o = torch.randn((1, num_dim))
inference = SNRE_A(prior=prior)
lkl_estimator = inference.append_simulations(theta, x).train()

potential_fn, parameter_transform = likelihood_estimator_based_potential(
    lkl_estimator, prior, x_o
)
posterior = MCMCPosterior(
    potential_fn, proposal=prior, theta_transform=parameter_transform, warmup_steps=10
)

posterior.sample((1000,), x=x_o)
```

Error message:

```
AttributeError                            Traceback (most recent call last)
in
     19 )
     20
---> 21 posterior.sample((1000,),x=x_o)

~/.local/lib/python3.10/site-packages/sbi/inference/posteriors/mcmc_posterior.py in sample(self, sample_shape, x, method, thin, warmup_steps, num_chains, init_strategy, init_strategy_parameters, init_strategy_num_candidates, mcmc_parameters, mcmc_method, sample_with, num_workers, show_progress_bars)
    273 self.potential_ = self._prepare_potential(method) # type: ignore
    274
--> 275 initial_params = self._get_initial_params(
    276 init_strategy, # type: ignore
    277 num_chains, # type: ignore

~/.local/lib/python3.10/site-packages/sbi/inference/posteriors/mcmc_posterior.py in _get_initial_params(self, init_strategy, num_chains, num_workers, show_progress_bars, **kwargs)
    420 else:
    421 initial_params = torch.cat(
--> 422 [init_fn() for _ in range(num_chains)] # type: ignore
    423 )
    424

~/.local/lib/python3.10/site-packages/sbi/inference/posteriors/mcmc_posterior.py in (.0)
    420 else:
    421 initial_params = torch.cat(
--> 422 [init_fn() for _ in range(num_chains)] # type: ignore
    423 )
    424

~/.local/lib/python3.10/site-packages/sbi/inference/posteriors/mcmc_posterior.py in ()
    353 )
    354 elif init_strategy == "resample":
--> 355 return lambda: resample_given_potential_fn(
    356 proposal, potential_fn, transform=transform, **kwargs
    357 )

~/.local/lib/python3.10/site-packages/sbi/samplers/mcmc/init_strategy.py in resample_given_potential_fn(proposal, potential_fn, transform, num_candidate_samples, num_batches, **kwargs)
     99 batch_draws = proposal.sample((num_candidate_samples,)).detach()
    100 init_param_candidates.append(batch_draws)
--> 101 log_weights.append(potential_fn(batch_draws).detach())
    102 log_weights = torch.cat(log_weights)
    103 init_param_candidates = torch.cat(init_param_candidates)

~/.local/lib/python3.10/site-packages/sbi/inference/potentials/likelihood_based_potential.py in __call__(self, theta, track_gradients)
     90
     91 # Calculate likelihood over trials and in one batch.
---> 92 log_likelihood_trial_sum = _log_likelihoods_over_trials(
     93 x=self.x_o,
     94 theta=theta.to(self.device),

~/.local/lib/python3.10/site-packages/sbi/inference/potentials/likelihood_based_potential.py in _log_likelihoods_over_trials(x, theta, net, track_gradients)
    138 # Calculate likelihood in one batch.
    139 with torch.set_grad_enabled(track_gradients):
--> 140 log_likelihood_trial_batch = net.log_prob(x_repeated, theta_repeated)
    141 # Reshape to (x-trials x parameters), sum over trial-log likelihoods.
    142 log_likelihood_trial_sum = log_likelihood_trial_batch.reshape(

~/.local/lib/python3.10/site-packages/torch/nn/modules/module.py in __getattr__(self, name)
   1707 if name in modules:
   1708 return modules[name]
-> 1709 raise AttributeError(f"'{type(self).__name__}' object has no attribute '{name}'")
   1710
   1711 def __setattr__(self, name: str, value: Union[Tensor, 'Module']) -> None:

AttributeError: 'Sequential' object has no attribute 'log_prob'
```

**Expected behavior** If I don't use `MCMCPosterior` and instead call `posterior = inference.build_posterior(lkl_estimator)` followed by the same `posterior.sample((1000,), x=x_o)`, I obtain samples from my posterior. However, I want to use `MCMCPosterior` directly.

**Additional context** Judging from your changelog, a similar issue seems to have been fixed after version 0.11.0, but I couldn't find the answer to my problem there.
michaeldeistler commented 3 weeks ago

Hi there, for SNRE you have to use `ratio_estimator_based_potential`, not `likelihood_estimator_based_potential`.
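The failure can be reproduced with plain PyTorch. As the last sbi frame of the traceback shows, the likelihood-based potential calls `net.log_prob(...)` on the trained network, but the network returned by `SNRE_A(...).train()` in this setup is a `torch.nn.Sequential` classifier (per the error message), which has no `log_prob`. The sketch below uses a small, hypothetical `nn.Sequential` as a stand-in; its layer sizes and the concatenated `(theta, x)` input are assumptions for illustration, not sbi's actual architecture.

```python
import torch
from torch import nn

num_dim = 2

# Hypothetical stand-in for the SNRE classifier: an nn.Sequential mapping a
# concatenated (theta, x) pair to one logit. Layer sizes are arbitrary.
classifier = nn.Sequential(
    nn.Linear(2 * num_dim, 16),
    nn.ReLU(),
    nn.Linear(16, 1),
)

theta = torch.randn(5, num_dim)
x_o = torch.randn(1, num_dim)

# A classifier only returns logits, one per (theta, x_o) pair.
logits = classifier(torch.cat([theta, x_o.expand(5, -1)], dim=1))
print(logits.shape)  # torch.Size([5, 1])

# It has no density-estimator interface, so calling log_prob fails exactly as
# in the traceback: nn.Module.__getattr__ raises AttributeError.
try:
    classifier.log_prob(x_o, theta)
except AttributeError as err:
    print(err)  # 'Sequential' object has no attribute 'log_prob'
```

This is why the fix is to pick the potential that matches the estimator type rather than to change the sampler.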

janfb commented 2 weeks ago

Changing your code to this would work:

```python
import torch

from sbi.inference import SNRE_A, MCMCPosterior, ratio_estimator_based_potential

num_dim = 2
prior = torch.distributions.MultivariateNormal(torch.zeros(num_dim), torch.eye(num_dim))
theta = prior.sample((1000,))
x = theta + torch.randn((1000, num_dim))
x_o = torch.randn((1, num_dim))
inference = SNRE_A(prior=prior)
# SNRE trains a classifier (a ratio estimator), not a likelihood density estimator.
ratio_estimator = inference.append_simulations(theta, x).train()

# Build the potential from the ratio estimator, not with the likelihood-based helper.
potential_fn, parameter_transform = ratio_estimator_based_potential(
    ratio_estimator, prior, x_o
)
posterior = MCMCPosterior(
    potential_fn, proposal=prior, theta_transform=parameter_transform, warmup_steps=10
)

posterior.sample((1000,), x=x_o)
```

Alternatively, use `build_posterior` directly instead of creating the potential and transform manually:

```python
posterior = inference.build_posterior(mcmc_parameters=dict(warmup_steps=10))
samples = posterior.sample((1000,), x=x_o)
```
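Conceptually, a ratio-based potential adds the classifier's logit, which for NRE-A approximates log r(θ, x_o) = log p(x_o | θ) − log p(x_o), to the prior log-density log p(θ). Since log p(x_o) does not depend on θ, the result is the unnormalized log posterior that MCMC needs. The function `ratio_potential` below is a hypothetical, simplified sketch of that computation for a single observation, using an untrained stand-in classifier with an assumed concatenated `(theta, x)` input; it is not sbi's implementation of `ratio_estimator_based_potential`.

```python
import torch
from torch import nn
from torch.distributions import MultivariateNormal

num_dim = 2
prior = MultivariateNormal(torch.zeros(num_dim), torch.eye(num_dim))

# Hypothetical, untrained stand-in classifier; its logit plays the role of
# log r(theta, x) = log p(x | theta) - log p(x).
classifier = nn.Sequential(nn.Linear(2 * num_dim, 16), nn.ReLU(), nn.Linear(16, 1))


def ratio_potential(theta: torch.Tensor, x_o: torch.Tensor) -> torch.Tensor:
    """Unnormalized log posterior: log r(theta, x_o) + log p(theta).

    theta has shape (batch, num_dim); x_o has shape (1, num_dim).
    """
    x_rep = x_o.expand(theta.shape[0], -1)  # pair every theta with x_o
    log_ratio = classifier(torch.cat([theta, x_rep], dim=1)).squeeze(-1)
    return log_ratio + prior.log_prob(theta)


theta = prior.sample((1000,))
x_o = torch.randn(1, num_dim)
with torch.no_grad():
    potentials = ratio_potential(theta, x_o)
print(potentials.shape)  # torch.Size([1000])
```

MCMC samplers only use differences of the potential between states, so the omitted constant log p(x_o) does not change the samples.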

Feel free to re-open the issue if there is anything unclear.

almoyer commented 2 weeks ago

Hi @janfb and @michaeldeistler,

Thanks a lot for your answers! I ran my code with the correct function and it worked. Thanks for the quick replies and for the example you gave me.