BlackHC opened this issue 2 months ago
Thanks for the input, Andreas! I wonder if something like this works for your case:
```diff
 def _glm_predictive_samples(
     self,
     f_mu: torch.Tensor,
     f_var: torch.Tensor,
+    link_function: Callable[[torch.Tensor], torch.Tensor] | None,
     n_samples: int,
     diagonal_output: bool = False,
     generator: torch.Generator | None = None,
 ) -> torch.Tensor:
```
Where

- `link_function = None` restores the current implementation,
- `link_function = lambda f: f` gets you sample logits, and
- `link_function = functools.partial(torch.log_softmax, dim=-1)` gets you sample log-softmax.

@aleximmer, @runame feel free to chime in. Looking for feedback before implementing this.
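For context, here is a minimal sketch of how a sampling body could use such a parameter. This is an illustration only, not the actual laplace-torch implementation; it assumes reparameterized sampling via a Cholesky factor of the functional covariance, and the helper name is made up:

```python
import torch
from typing import Callable

def _glm_predictive_samples_sketch(
    f_mu: torch.Tensor,    # (batch, n_outputs) predictive mean
    f_var: torch.Tensor,   # (batch, n_outputs, n_outputs) predictive covariance
    link_function: Callable[[torch.Tensor], torch.Tensor] | None,
    n_samples: int,
    generator: torch.Generator | None = None,
) -> torch.Tensor:
    # Reparameterized draws from N(f_mu, f_var): mu + L @ eps, with L the Cholesky factor.
    scale_tril = torch.linalg.cholesky(f_var)
    eps = torch.randn(
        n_samples, *f_mu.shape,
        generator=generator, dtype=f_mu.dtype, device=f_mu.device,
    )
    samples = f_mu + (scale_tril @ eps.unsqueeze(-1)).squeeze(-1)
    if link_function is None:
        # Current behavior for classification: push samples through softmax.
        return torch.softmax(samples, dim=-1)
    return link_function(samples)
```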
Aww, yeah, that would be great! It would cover all my use cases and provide a nice extensible interface.
Sounds like a good improvement!
The code in `_glm_predictive_samples` always applies `torch.softmax` to the results under classification. For numerical stability, supporting `torch.log_softmax` here would be helpful. Similarly, it would be helpful if there were an easy way to obtain the logits without having to change `self.likelihood` intermittently.

Thanks,
Andreas
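To illustrate the numerical-stability point (a generic PyTorch example, not code from the library): taking `log` of an already-computed `softmax` underflows for large logit gaps, whereas `log_softmax` works in log space and stays finite.

```python
import torch

logits = torch.tensor([100.0, 0.0, -100.0])

# Naive route: softmax first, then log. exp(-200) underflows to 0 in float32,
# so the last entry becomes log(0) = -inf.
print(torch.log(torch.softmax(logits, dim=-1)))  # ≈ tensor([0., -100., -inf])

# log_softmax computes log(softmax(x)) directly in log space and stays finite.
print(torch.log_softmax(logits, dim=-1))         # tensor([0., -100., -200.])
```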