SamDuffield closed this 4 months ago.
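Currently, if you try to use e.g. `laplace.dense_ggn` with the following `forward` and `outer_log_lik` functions:

```python
import torch

# `model` (a torch.nn.Module) and `obs_sd` are assumed to be defined elsewhere.

def forward(p, b):
    x, _ = b
    return torch.func.functional_call(model, p, x), torch.tensor([])


def outer_log_lik(y_pred, b):
    _, y = b
    return torch.distributions.Normal(y_pred, obs_sd).log_prob(y).sum()
```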
It will throw an error: `TypeError: forward() got an unexpected keyword argument 'batch'`
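A minimal pure-Python sketch of why this happens, assuming the library binds the batch with something like `partial(forward, batch=batch)`. The helper `bind_batch_with_partial` and the toy `forward` are hypothetical stand-ins, not the library's code. Binding by keyword only works if the user's second parameter happens to be named `batch`:

```python
from functools import partial


def forward(p, b):
    # User-defined forward whose second parameter is named `b`, not `batch`.
    x, _ = b
    return p * x, None


def bind_batch_with_partial(fn, batch):
    # Hypothetical stand-in for the library internals: binds the batch by
    # keyword, which silently assumes the parameter is called `batch`.
    return partial(fn, batch=batch)


batch = (2.0, 0.0)
f = bind_batch_with_partial(forward, batch)

try:
    f(3.0)  # becomes forward(3.0, batch=(2.0, 0.0))
except TypeError as err:
    print(err)  # forward() got an unexpected keyword argument 'batch'
```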
We can fix this by using a `lambda` rather than `partial`, so that the batch is passed to `forward` positionally instead of as the keyword argument `batch`.
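A sketch of the `lambda` approach, using a hypothetical helper `bind_batch_with_lambda` and a toy `forward` in place of the real model call (both are illustrative assumptions, not the library's implementation):

```python
def forward(p, b):
    # Same user-defined signature as in the issue: second parameter is `b`.
    x, _ = b
    return p * x, None


def bind_batch_with_lambda(fn, batch):
    # Hypothetical helper: pass the batch positionally, so the name of the
    # user's second parameter no longer matters.
    return lambda p: fn(p, batch)


batch = (2.0, 0.0)
f = bind_batch_with_lambda(forward, batch)
print(f(3.0))  # (6.0, None)
```

Note that a positional `partial(forward, batch)` would not work either: `partial` binds positional arguments from the left, so the batch would fill `p` rather than `b`. The `lambda` is the simplest way to fix the second positional argument.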