aleximmer / Laplace

Laplace approximations for Deep Learning.
https://aleximmer.github.io/Laplace
MIT License

Replace hardcoded class index with `logit_class_dim` argument #177

Open wiseodd opened 2 months ago

wiseodd commented 2 months ago

Closes #163

Please wait until #144 is merged.

wiseodd commented 2 months ago

Test cases are still WIP.

wiseodd commented 2 months ago

Ready for review!

Discussion points:

  1. Do BackPACK, ASDL, and Asdfghjkl even support multiple output dims? That is, if we flatten `logits = logits.view(-1, logits.size(logit_class_dim))`, do they even compute the correct quantities? (See the sketch below.)
  2. The Curvlinops and torch.func interfaces always assume `logit_class_dim = 1`. Do we want to make them respect `logit_class_dim`? I don't think flattening as above is the correct approach, right?
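A minimal sketch of the flattening in point 1 (shapes made up for illustration, assuming image-shaped logits with `logit_class_dim = 1`). Note that the class dim has to be moved last before flattening, so each spatial location keeps its own group of `n_classes` logits:

```python
import torch

# Hypothetical shapes for illustration only.
logit_class_dim = 1
logits = torch.randn(8, 10, 32, 32)  # (batch, n_classes, height, width)

# Move the class dim to the end, then fold all remaining dims into the
# batch dim, so every spatial location becomes its own "sample" with
# n_classes logits. reshape (not view) is needed after movedim, since
# the tensor is no longer contiguous.
flat = logits.movedim(logit_class_dim, -1).reshape(-1, logits.size(logit_class_dim))
print(flat.shape)  # torch.Size([8192, 10])
```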
wiseodd commented 2 weeks ago

Merged with main and ready for review!

wiseodd commented 1 day ago

It seems more complicated than anticipated. This PR is useful for models with image outputs like diffusion models.

Considering v0.2 is all about LLMs, let's defer this to v0.3!

runame commented 1 day ago

> It seems more complicated than anticipated. This PR is useful for models with image outputs like diffusion models.
>
> Considering v0.2 is all about LLMs, let's defer this to v0.3!

Ok, for now maybe we can add a note to the README and the docstrings that clearly states how multi-dim outputs are handled?

wiseodd commented 1 day ago

So what do you have in mind regarding the wording? Something like this in README.md?

```markdown
## Caveats

- Currently, this library always assumes that the model has an output
  tensor of shape `(batch_size, ..., n_classes)`, so in the case of
  image outputs, you need to rearrange from NCHW to NHWC.
```
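For instance, the rearrangement the caveat describes is a single permute; a minimal sketch with made-up shapes:

```python
import torch

# Made-up shapes for illustration: a segmentation-style output in NCHW.
logits_nchw = torch.randn(8, 10, 32, 32)       # (batch, classes, H, W)

# Rearrange to NHWC so the class dim is last, as the library expects.
logits_nhwc = logits_nchw.permute(0, 2, 3, 1)  # (batch, H, W, classes)
print(logits_nhwc.shape)  # torch.Size([8, 32, 32, 10])
```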
runame commented 1 day ago

> So what do you have in mind regarding the wording? Something like this in README.md?

Yes, this is exactly what I was thinking!