aleximmer / Laplace

Laplace approximations for Deep Learning.
https://aleximmer.github.io/Laplace
MIT License

Enable BayesOpt #120

Closed · wiseodd closed this 1 year ago

wiseodd commented 1 year ago

Changes:

aleximmer commented 1 year ago

Looks good to me. A few comments/questions:

wiseodd commented 1 year ago

@AlexImmer

Could you add a small test case where the predictives are currently tested, but with the setting joint=True, asserting 1) an equal marginal predictive and 2) the right shape?
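
A minimal sketch of what such a test could look like, assuming a fitted Laplace instance la and a test batch X_test from the existing predictive-test fixtures (names here are illustrative, not the actual test code):

import torch

# Marginal predictive: per-point means and variances
f_mu, f_var = la(X_test, pred_type='glm', joint=False)
# Joint predictive: mean and a full (n*k, n*k) covariance
f_mu_joint, f_cov = la(X_test, pred_type='glm', joint=True)

n, k = f_mu.shape
# 1) the joint mean should equal the marginal mean
assert torch.allclose(f_mu.flatten(), f_mu_joint.flatten())
# 2) the joint covariance should have the right shape
assert f_cov.shape == (n * k, n * k)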

  • To differentiate the predictive uncertainty as well, which depends on Jacobians, the Jacobians would have to be differentiable. Don't you need that? This would require create_graph=True in addition to retain_graph=True.

Done in bc50025.
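
As a toy illustration (not the library code) of why create_graph=True is needed on top of retain_graph=True:

import torch

x = torch.randn(3, requires_grad=True)
y = (x ** 2).sum()

# retain_graph=True only allows a second backward pass;
# the returned gradient is a constant w.r.t. x.
g, = torch.autograd.grad(y, x, retain_graph=True)
print(g.requires_grad)  # False

# create_graph=True keeps the gradient itself on the graph, so
# Jacobian-based quantities (e.g. predictive variances) stay differentiable.
g, = torch.autograd.grad(y, x, create_graph=True)
print(g.requires_grad)  # True
(g ** 2).sum().backward()  # second-order backprop works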

  • How well does the Laplace model work vs. the GP on these toy problems?

Seems to work reasonably well:

agustinuskristiadi@akmbp ~/Projects/Laplace                                                               [13:37:28]
(base) > $ python examples/bayesopt_example.py --exp_len 50 --test_func branin --acqf qEI                [±detach ●]

Test Function: branin
-------------------------------------

[BNN-LA, MSE = 85.071; Best f(x) = 0.846, curr f(x) = 3.055]: 100%|██████████████████| 50/50 [01:58<00:00,  2.37s/it]
[GP, MSE = 2081.004; Best f(x) = 0.510, curr f(x) = 12.377]: 100%|███████████████████| 50/50 [00:13<00:00,  3.83it/s]

agustinuskristiadi@akmbp ~/Projects/Laplace                                                               [13:39:57]
(base) > $ python examples/bayesopt_example.py --exp_len 50 --test_func branin --acqf EI                 [±detach ●]

Test Function: branin
-------------------------------------

[BNN-LA, MSE = 183.844; Best f(x) = 0.939, curr f(x) = 125.097]: 100%|███████████████| 50/50 [00:59<00:00,  1.19s/it]
[GP, MSE = 4422.465; Best f(x) = 1.199, curr f(x) = 2.495]: 100%|████████████████████| 50/50 [00:06<00:00,  8.30it/s]

@runame

Are there cases where you might want to differentiate through NN samples?

There are, e.g. in Thompson sampling (https://arxiv.org/abs/1706.01825). I added the option.
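
Roughly, the use case looks like this (a sketch only; la is a fitted Laplace instance, d the input dimension, and I'm assuming the predictive_samples interface with the new non-detaching option enabled):

import torch

d = 2
x = torch.randn(1, d, requires_grad=True)
opt = torch.optim.Adam([x], lr=1e-2)

for _ in range(100):
    opt.zero_grad()
    # One posterior sample of f(x); with the new option the sample is
    # not detached, so gradients can flow back to x.
    f_sample = la.predictive_samples(x, n_samples=1).squeeze(0)
    f_sample.sum().backward()
    opt.step()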

Also, what about LowRankLaplace

There's already a functional_covariance function for LowRankLaplace here: https://github.com/AlexImmer/Laplace/blob/6e50ab4e774b2834b11aa5e984bdd1c069255b49/laplace/baselaplace.py#L1016. Do I need to do something else?

LLLaplace?

Done in f0981a7.

wiseodd commented 1 year ago

@AlexImmer, @runame: For making the last-layer Jacobians optionally backpropable, I need to do the following. I'm not sure how to make forward_with_features communicate with the hook, though.

def forward_with_features(self, x: torch.Tensor, enable_backprop: bool = True):
    """
    Add an option `enable_backprop` that decides whether the stored
    features stay attached to the autograd graph.
    """
    pass

def _get_hook(self, name: str) -> Callable:
    def hook(_, input, __):
        # Depending on the above `enable_backprop`, we detach here (or not)
        self._features[name] = input[0].detach()
    return hook

Should I make enable_backprop a class-level attribute? (In this case I also need to do the same with BackPackInterface.)
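
Something like this, as a sketch of the class-level-attribute route (illustrative, not the final implementation):

import torch
import torch.nn as nn
from typing import Callable, Dict

class FeatureExtractor(nn.Module):
    def __init__(self, model: nn.Module, enable_backprop: bool = False) -> None:
        super().__init__()
        self.model = model
        self.enable_backprop = enable_backprop
        self._features: Dict[str, torch.Tensor] = {}

    def _get_hook(self, name: str) -> Callable:
        def hook(_, input, __):
            # The hook closes over `self`, so it sees `enable_backprop`
            # without `forward_with_features` having to pass it through.
            feat = input[0]
            self._features[name] = feat if self.enable_backprop else feat.detach()
        return hook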

runame commented 1 year ago

Also, what about LowRankLaplace

There's already a functional_covariance function for LowRankLaplace here: Laplace/laplace/baselaplace.py, line 1016 in 6e50ab4 (def functional_covariance(self, Jacs):). Do I need to do something else?

I was asking because of this line: you detached self.mean for all ParametricLaplace subclasses, but not for LowRankLaplace.

runame commented 1 year ago

For making the last-layer Jacobians optionally backpropable, I need to do the following. (...)

Yeah, I think a class-level attribute is probably the best solution. Also, maybe it makes sense to consistently name the argument that decides whether to detach or not; in some places it's called detach and here enable_backprop.

wiseodd commented 1 year ago

@AlexImmer, @runame how do you actually enable backprop in ASDL's jacobian here? (I'm unfamiliar with its API.)

wiseodd commented 1 year ago

Yeah, I think a class-level attribute is probably the best solution. Also, maybe it makes sense to consistently name the argument that decides whether to detach or not; in some places it's called detach and here enable_backprop.

See 7726b3d

wiseodd commented 1 year ago

@AlexImmer @runame: There's nothing left for me to do in this PR, so you can review it again.

As for enabling all of this in the ASDL backend, how about we invite Michael Aerni to contribute? If he already has the code, I can also do it; in that case, he simply needs to push his changes to a new branch. In any case, let's do this in a separate PR.

wiseodd commented 1 year ago

Fixed, @runame.