normal-computing / posteriors

Uncertainty quantification with PyTorch
https://normal-computing.github.io/posteriors/
Apache License 2.0

More flexible Laplace GGN arguments #91

Closed: SamDuffield closed this issue 4 months ago

SamDuffield commented 4 months ago

Currently, if you try to use e.g. laplace.dense_ggn with

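import torch

# `model` (a torch.nn.Module) and `obs_sd` (the observation noise scale)
# are defined elsewhere. Note the batch argument is named `b`, not `batch`.
def forward(p, b):
    x, _ = b
    return torch.func.functional_call(model, p, x), torch.tensor([])

def outer_log_lik(y_pred, b):
    _, y = b
    return torch.distributions.Normal(y_pred, obs_sd).log_prob(y).sum()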

it throws an error:

TypeError: forward() got an unexpected keyword argument 'batch'

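We can fix this by binding the batch with a lambda rather than functools.partial, so that the batch is passed positionally and the user's forward (and outer_log_lik) can name that argument anything.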
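A minimal, framework-free sketch of the difference follows. It uses neither posteriors nor PyTorch; `bind_batch_keyword` and `bind_batch_positional` are hypothetical helpers standing in for however the library binds the batch internally, assuming (as the error message suggests) that the current code uses functools.partial with a `batch=` keyword.

```python
from functools import partial


def forward(params, b):
    """Toy stand-in for a user-supplied forward function.

    As in the example above, the batch argument is named `b`, not `batch`.
    Returns a (predictions, aux) tuple.
    """
    x, _ = b
    return [params["w"] * xi for xi in x], None


def bind_batch_keyword(fn, batch):
    # Hypothetical: binds the batch by keyword, which only works if the
    # user's function names its second argument exactly `batch`.
    return partial(fn, batch=batch)


def bind_batch_positional(fn, batch):
    # Hypothetical: passes the batch positionally, so the user's function
    # can name its second argument anything.
    return lambda p: fn(p, batch)


params = {"w": 2.0}
batch = ([1.0, 2.0, 3.0], [2.0, 4.0, 6.0])

# Keyword binding fails because `forward` has no parameter called `batch`.
partial_error = None
try:
    bind_batch_keyword(forward, batch)(params)
except TypeError as err:
    partial_error = str(err)
print(partial_error)  # reports the unexpected keyword argument 'batch'

# Positional binding works regardless of the parameter name.
y_pred, aux = bind_batch_positional(forward, batch)(params)
print(y_pred)  # [2.0, 4.0, 6.0]
```

The trade-off of binding positionally is that the library then relies on the batch always being the second positional parameter of forward and outer_log_lik, rather than on its name.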