AdamCobb / hamiltorch

PyTorch-based library for Riemannian Manifold Hamiltonian Monte Carlo (RMHMC) and inference in Bayesian neural networks
BSD 2-Clause "Simplified" License
426 stars, 63 forks

log_prob() with additional arguments #26

Open aaschwanden opened 1 year ago

aaschwanden commented 1 year ago

My log_prob function takes additional arguments besides the parameter vector `p` that I want to estimate:

```python
import torch

# Define the log probability (sum of log likelihood and log prior)
def log_prob(p: torch.Tensor, time: torch.Tensor, obs: torch.Tensor):
    return log_likelihood(p, time, obs) + log_prior(p)

# Define the log prior: Gaussian with covariance 25 * I
def log_prior(p: torch.Tensor):
    p_size = p.size()[0]
    prior = torch.distributions.MultivariateNormal(
        torch.zeros(p_size), torch.eye(p_size) * 25, validate_args=False
    )
    return prior.log_prob(p).sum()

# Define the log likelihood P(obs | p): Bernoulli distribution
def log_likelihood(p: torch.Tensor, time: torch.Tensor, obs: torch.Tensor):
    prob = 1.0 / (1.0 + torch.exp(p[0] + p[1] * time))
    return torch.distributions.Bernoulli(prob, validate_args=False).log_prob(obs).sum()
```

Is there a way to pass these additional arguments through `hamiltorch.sample`?

Calling

```python
params_hmc = hamiltorch.sample(log_prob_func=log_prob, ...)
```

unsurprisingly results in

```
TypeError: log_prob() missing 2 required positional arguments: 'time' and 'obs'
```

as the documentation says:

```
log_prob_func : function
    A log_prob_func must take a 1-d vector of length equal to the number of parameters that are being sampled.
```
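One standard Python way to meet this contract (a general-purpose pattern, not something taken from the hamiltorch documentation) is to bind the data ahead of time with `functools.partial` or a closure, so the function handed to the sampler takes only `p`. The sketch below re-declares the model from this issue so it runs on its own; the names `log_prob_p` and `log_prob_closure` and the `time`/`obs` tensors are hypothetical, made up purely for illustration.

```python
import functools

import torch


def log_prior(p: torch.Tensor) -> torch.Tensor:
    # Gaussian prior with covariance 25 * I (standard deviation 5 per parameter).
    n = p.shape[0]
    prior = torch.distributions.MultivariateNormal(
        torch.zeros(n), torch.eye(n) * 25, validate_args=False
    )
    return prior.log_prob(p)


def log_likelihood(p: torch.Tensor, time: torch.Tensor, obs: torch.Tensor) -> torch.Tensor:
    # Bernoulli likelihood with prob = 1 / (1 + exp(p[0] + p[1] * time)).
    prob = 1.0 / (1.0 + torch.exp(p[0] + p[1] * time))
    return torch.distributions.Bernoulli(prob, validate_args=False).log_prob(obs).sum()


def log_prob(p: torch.Tensor, time: torch.Tensor, obs: torch.Tensor) -> torch.Tensor:
    return log_likelihood(p, time, obs) + log_prior(p)


# Hypothetical example data, for illustration only.
time = torch.linspace(0.0, 10.0, 20)
obs = (time > 5.0).float()

# Option 1: functools.partial fixes time and obs, leaving p as the only argument.
log_prob_p = functools.partial(log_prob, time=time, obs=obs)


# Option 2: an equivalent closure over the same data.
def log_prob_closure(p: torch.Tensor) -> torch.Tensor:
    return log_prob(p, time, obs)


# Either callable takes one 1-d parameter vector and returns a scalar that
# autograd can differentiate, which is what gradient-based HMC relies on.
p0 = torch.zeros(2, requires_grad=True)
lp = log_prob_p(p0)
lp.backward()
print(lp.item(), p0.grad)
```

The bound function would then be passed as `hamiltorch.sample(log_prob_func=log_prob_p, ...)`, with the remaining arguments unchanged from the original call. `partial` is the more compact choice; the closure makes it easier to add extra logic (for example, reshaping `p`) later.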

I've also tried

```python
params_hmc = hamiltorch.sample(log_prob_func=log_prob(time, obs), ...)
```

and

```python
params_hmc = hamiltorch.sample(log_prob_func=log_prob(time=time, obs=obs), ...)
```

both of which throw errors.
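Both attempts likely fail for the same reason: `log_prob(time, obs)` and `log_prob(time=time, obs=obs)` *call* `log_prob` immediately instead of producing a function, so Python raises a `TypeError` before `hamiltorch.sample` ever runs. The first call binds `time` to `p` and `obs` to `time`, leaving `obs` missing; the second leaves `p` missing. The minimal, torch-free sketch below reproduces this with a hypothetical stand-in `log_prob` that has the same signature but a dummy body, and the hypothetical helper `error_of` just captures the error message.

```python
import functools


def log_prob(p, time, obs):
    # Hypothetical stand-in with the same signature as the real model; dummy score.
    return -(p - time) ** 2 - obs


time, obs = 1.0, 0.5


def error_of(thunk):
    """Return the TypeError message raised by thunk(), or None if it succeeds."""
    try:
        thunk()
    except TypeError as exc:
        return str(exc)
    return None


# Attempt 1: log_prob is called with time -> p and obs -> time; 'obs' is missing.
msg1 = error_of(lambda: log_prob(time, obs))
# Attempt 2: log_prob is called with no value for 'p'.
msg2 = error_of(lambda: log_prob(time=time, obs=obs))
print(msg1)
print(msg2)

# Binding instead of calling yields a one-argument function of p.
log_prob_p = functools.partial(log_prob, time=time, obs=obs)
print(log_prob_p(2.0))  # prints -1.5
```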