normal-computing / posteriors

Uncertainty quantification with PyTorch
https://normal-computing.github.io/posteriors/
Apache License 2.0

`linearized_forward_diag` API #57

Open: SamDuffield opened this issue 2 months ago

SamDuffield commented 2 months ago

Currently we have an API like

```python
vals, chol, aux = linearized_forward_diag(f, params, batch, sd_diag)
```

but perhaps an API like

```python
vals, chol, aux = linearized_forward_diag(f, sd_diag)(params, batch)
```

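might be cleaner, as it returns a new function that keeps the `(params, batch)` signature of `f`. It also fits better with the `torch.func` API, whose transforms (e.g. `torch.func.grad`, `torch.func.jacrev`) likewise take a function and return a new function.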
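A minimal sketch of what the curried form could look like, assuming `f(params, batch)` returns `(output, aux)` with a 1-D `output`, that `params` and `sd_diag` are dicts of tensors with matching keys and shapes, and that `chol` is the Cholesky factor of the linearized output covariance J diag(sd_diag²) Jᵀ. This is an illustration built on `torch.func.jacrev`, not the `posteriors` implementation; the inner function name `linearized` and the toy model `f` are hypothetical.

```python
import torch
from torch.func import jacrev


def linearized_forward_diag(f, sd_diag):
    """Hypothetical curried form of the proposed API (a sketch only).

    Assumes f(params, batch) -> (output, aux), with params and sd_diag being
    dicts of tensors with matching keys/shapes and output a 1-D tensor.
    Returns a new function with the same (params, batch) signature as f.
    """

    def linearized(params, batch):
        # Forward pass for the predictions and any auxiliary information.
        out, aux = f(params, batch)

        # Jacobian of the output w.r.t. each parameter tensor; each entry has
        # shape out.shape + param.shape.
        jac = jacrev(lambda p: f(p, batch)[0])(params)

        # Flatten into one (n_outputs, n_params) matrix J and a matching
        # vector of parameter standard deviations.
        keys = list(params.keys())
        J = torch.cat([jac[k].reshape(out.numel(), -1) for k in keys], dim=1)
        sd = torch.cat([sd_diag[k].reshape(-1) for k in keys])

        # Linearized output covariance J diag(sd^2) J^T and its (lower)
        # Cholesky factor; this requires the covariance to be positive definite.
        cov = (J * sd**2) @ J.T
        chol = torch.linalg.cholesky(cov)
        return out, chol, aux

    return linearized


# Usage with a toy linear model (hypothetical example).
def f(params, batch):
    out = batch @ params["w"] + params["b"]
    return out, out.mean()


params = {"w": torch.tensor([1.0, 2.0]), "b": torch.tensor(0.5)}
sd_diag = {"w": torch.ones(2), "b": torch.tensor(1.0)}
batch = torch.tensor([[1.0, 0.0], [0.0, 1.0], [1.0, 1.0]])

vals, chol, aux = linearized_forward_diag(f, sd_diag)(params, batch)
```

Because the returned function has the same `(params, batch)` signature as `f`, it can be built once and then called on each new batch without re-passing `f` and `sd_diag`. Note that the Cholesky factorization fails when J diag(sd_diag²) Jᵀ is singular (e.g. more outputs than parameters), so in practice a small diagonal jitter may be needed.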

SamDuffield commented 2 months ago

Additional considerations: