aleximmer / Laplace

Laplace approximations for Deep Learning.
https://aleximmer.github.io/Laplace
MIT License

Feature Request: Allow sampling of log probs and logits for Likelihood.CLASSIFICATION #241

Open BlackHC opened 2 months ago

BlackHC commented 2 months ago

The code in `_glm_predictive_samples` always applies `torch.softmax` to the results under classification.

For numerical stability, supporting `torch.log_softmax` here would be helpful. Similarly, it would be helpful to have an easy way to obtain the raw logits without temporarily changing `self.likelihood`.
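To illustrate why a fused log-softmax matters: taking `log(softmax(x))` in two steps overflows or underflows for large logits, while the log-sum-exp formulation stays finite. A minimal pure-Python sketch of the two approaches (illustration only, not library code):

```python
import math

def stable_log_softmax(logits):
    # Log-sum-exp trick: log_softmax(x)_i = x_i - max(x) - log(sum_j exp(x_j - max(x)))
    m = max(logits)
    lse = m + math.log(sum(math.exp(x - m) for x in logits))
    return [x - lse for x in logits]

def naive_log_softmax(logits):
    # Naive two-step version: exp() overflows for large logits.
    z = sum(math.exp(x) for x in logits)
    return [math.log(math.exp(x) / z) for x in logits]

logits = [1000.0, 0.0]
stable_log_softmax(logits)   # finite: [0.0, -1000.0]
# naive_log_softmax(logits) raises OverflowError at math.exp(1000.0)
```

This is the same reason `torch.log_softmax` exists alongside `torch.softmax`.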

Thanks,
Andreas

wiseodd commented 2 months ago

Thanks for the input, Andreas! I wonder if something like this works for your case:

def _glm_predictive_samples(
    self,
    f_mu: torch.Tensor,
    f_var: torch.Tensor,
+   link_function: Callable[[torch.Tensor], torch.Tensor] | None,
    n_samples: int,
    diagonal_output: bool = False,
    generator: torch.Generator | None = None,
) -> torch.Tensor:

Where `link_function=None` would keep the current behavior (`torch.softmax` under classification), and passing e.g. `torch.log_softmax` or the identity would yield log-probs or raw logits instead.
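A sketch of how the sampling body might apply such a parameter. Everything here (the helper name `apply_link`, the `None` fallback) is hypothetical and only mirrors the proposal, not the library's actual implementation:

```python
import torch
from typing import Callable, Optional

def apply_link(
    f_samples: torch.Tensor,
    link_function: Optional[Callable[[torch.Tensor], torch.Tensor]] = None,
) -> torch.Tensor:
    # Hypothetical helper: with link_function=None we mimic the current
    # classification behavior (softmax over the last dim); a custom callable
    # can produce log-probs or return the raw logits unchanged.
    if link_function is None:
        return torch.softmax(f_samples, dim=-1)
    return link_function(f_samples)

# Usage: log-probs via a stable log-softmax, or logits via the identity.
samples = torch.randn(5, 3)
log_probs = apply_link(samples, lambda f: torch.log_softmax(f, dim=-1))
logits = apply_link(samples, lambda f: f)
```

This keeps the default path identical to today's behavior while covering both use cases from the issue.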

@aleximmer, @runame feel free to chime in.

Looking for feedback before implementing this.

BlackHC commented 2 months ago

Aww, yeah, that would be great! It would cover all my use cases and provide a nice extensible interface.

runame commented 2 months ago

Sounds like a good improvement!