Closed wiseodd closed 1 year ago
Looks good to me. A few comments/questions:
- Could you add a small test case where the predictives are currently tested, but with the setting `joint=True`, asserting 1) equal marginal predictive and 2) right shape?
- To differentiate the predictive uncertainty as well, which depends on Jacobians, the Jacobians would have to be differentiable. Don't you need that? This would require `create_graph=True` apart from `retain_graph=True`.
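For context, here is a minimal plain-PyTorch sketch (toy model, hypothetical names) of why `create_graph=True` is needed: the Jacobian computation itself must stay in the autograd graph so that a predictive variance built from the Jacobian remains differentiable w.r.t. the input.

```python
import torch

# Tiny model f: R^2 -> R; the "Jacobian" here is just the gradient df/dx.
w = torch.tensor([1.5, -0.5])

def f(x):
    return torch.tanh(x @ w)

x = torch.tensor([0.3, 0.7], requires_grad=True)
y = f(x)

# create_graph=True keeps the graph of the gradient computation itself,
# so quantities built from the Jacobian stay differentiable w.r.t. x.
(jac,) = torch.autograd.grad(y, x, create_graph=True)

# A GLM-style predictive variance J Sigma J^T (Sigma = identity here).
var = jac @ jac

# Because of create_graph=True, we can backprop through the variance,
# e.g. to optimize an acquisition function over x.
(dvar_dx,) = torch.autograd.grad(var, x)
print(dvar_dx.shape)  # torch.Size([2])
```

With `create_graph=False` (the default), `jac` would be cut out of the graph and the second `torch.autograd.grad` call would fail.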
Done in bc50025.
- How well does the Laplace model work vs. the GP on these toy problems?
Seems to work reasonably:
```shell
agustinuskristiadi@akmbp ~/Projects/Laplace [13:37:28]
(base) > $ python examples/bayesopt_example.py --exp_len 50 --test_func branin --acqf qEI [±detach ●]
Test Function: branin
-------------------------------------
[BNN-LA, MSE = 85.071; Best f(x) = 0.846, curr f(x) = 3.055]: 100%|██████████████████| 50/50 [01:58<00:00, 2.37s/it]
[GP, MSE = 2081.004; Best f(x) = 0.510, curr f(x) = 12.377]: 100%|███████████████████| 50/50 [00:13<00:00, 3.83it/s]

agustinuskristiadi@akmbp ~/Projects/Laplace [13:39:57]
(base) > $ python examples/bayesopt_example.py --exp_len 50 --test_func branin --acqf EI [±detach ●]
Test Function: branin
-------------------------------------
[BNN-LA, MSE = 183.844; Best f(x) = 0.939, curr f(x) = 125.097]: 100%|███████████████| 50/50 [00:59<00:00, 1.19s/it]
[GP, MSE = 4422.465; Best f(x) = 1.199, curr f(x) = 2.495]: 100%|████████████████████| 50/50 [00:06<00:00, 8.30it/s]
```
@runame

> Are there cases where you might want to differentiate through NN samples?
There are, e.g. in Thompson sampling, as in https://arxiv.org/abs/1706.01825. I added the option.
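To illustrate the Thompson-sampling use case, here is a hypothetical sketch (made-up toy posterior, not the library's API): draw one weight sample `theta` from the posterior, then differentiate the sampled function w.r.t. the input `x` to optimize it by gradient ascent. This only works if the sample's forward pass is not detached.

```python
import torch

# Made-up diagonal Gaussian posterior over the weights of a toy model.
torch.manual_seed(0)
mu = torch.tensor([1.0, -2.0])       # posterior mean (illustrative)
sigma = torch.tensor([0.1, 0.2])     # posterior std (illustrative)

theta = mu + sigma * torch.randn(2)  # one posterior sample, held fixed

# Thompson sampling: maximize the *sampled* objective over the input x.
x = torch.zeros(2, requires_grad=True)
fx = x @ theta - 0.5 * (x @ x)       # toy sampled objective
fx.backward()

# At x = 0 the ascent direction is exactly theta, so gradient steps move x
# toward the sample's optimum; an optimizer would iterate this update.
print(x.grad)
```

If the sample path were detached, `fx.backward()` would yield no useful gradient for optimizing `x`.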
> Also, what about LowRankLaplace?

There's already a `functional_covariance` function for `LowRankLaplace` here: https://github.com/AlexImmer/Laplace/blob/6e50ab4e774b2834b11aa5e984bdd1c069255b49/laplace/baselaplace.py#L1016. Do I need to do something else?
@AlexImmer, @runame: For making the last-layer Jacobians optionally backpropable, I need to do the following. I'm not sure how to make `forward_with_features` communicate with the hook, though.

```python
def forward_with_features(self, x: torch.Tensor, enable_backprop=True):
    """
    Add an option `enable_backprop`.
    """
    pass

def _get_hook(self, name: str) -> Callable:
    def hook(_, input, __):
        # Depending on the above `enable_backprop`, we detach
        self._features[name] = input[0].detach()
    return hook
```
Should I make `enable_backprop` a class-level attribute? (In this case I also need to do the same with `BackPackInterface`.)
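One way the class-level-attribute idea could look (a standalone sketch with a hypothetical `FeatureExtractor` class, not the library's actual code): the hook reads `self.enable_backprop` at call time, so `forward_with_features` never needs to communicate with it directly.

```python
import torch
from torch import nn
from typing import Callable

class FeatureExtractor(nn.Module):
    """Stores the input to a named submodule; detaching is controlled
    by a class-level flag instead of a per-call argument."""

    def __init__(self, model: nn.Module, last_layer_name: str,
                 enable_backprop: bool = False):
        super().__init__()
        self.model = model
        self.enable_backprop = enable_backprop
        self._features = {}
        layer = getattr(model, last_layer_name)
        layer.register_forward_hook(self._get_hook(last_layer_name))

    def _get_hook(self, name: str) -> Callable:
        def hook(_, input, __):
            feats = input[0]
            # The hook consults the attribute at call time, so no
            # argument-passing between forward and hook is needed.
            self._features[name] = feats if self.enable_backprop else feats.detach()
        return hook

    def forward_with_features(self, x: torch.Tensor):
        out = self.model(x)
        return out, next(iter(self._features.values()))

model = nn.Sequential(nn.Linear(3, 4), nn.ReLU(), nn.Linear(4, 2))
fe = FeatureExtractor(model, '2', enable_backprop=True)
x = torch.randn(5, 3, requires_grad=True)
out, feats = fe.forward_with_features(x)
print(feats.requires_grad)  # True; with enable_backprop=False it would be False
```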
> Also, what about LowRankLaplace?
>
> There's already a `functional_covariance` function for `LowRankLaplace` here: Laplace/laplace/baselaplace.py, line 1016 in 6e50ab4 (`def functional_covariance(self, Jacs):`). Do I need to do something else?
I was asking because of this line: you detached `self.mean` for all `ParametricLaplace` subclasses, but not for `LowRankLaplace`.
> For making the last-layer Jacobians optionally backpropable, I need to do the following. (...)
Yeah, I think a class-level attribute is probably the best solution. Also, maybe it makes sense to consistently name the argument which decides whether to detach or not; in some places it's called `detach` and here `enable_backprop`.
@AlexImmer, @runame how do you actually enable backprop in ASDL's jacobian here? (I'm unfamiliar with its API.)
> Yeah, I think a class-level attribute is probably the best solution. Also, maybe it makes sense to consistently name the argument which decides whether to detach or not; in some places it's called `detach` and here `enable_backprop`.
See 7726b3d.
@AlexImmer @runame: There's nothing more for me to do in this PR, so you can review it again.
As for enabling all of this in the ASDL backend, how about we invite Michael Aerni to contribute? If he already has the code, I can also do it; in that case, he simply needs to push his changes to a new branch. In any case, let's do this in a separate PR.
Fixed @runame
Changes:

- Optional `detach()` in the GLM predictive so that one can backprop from `f(x)` to `x`, in particular for optimizing acquisition functions.
- `functional_covariance` method for the GLM predictive that returns the `(nk,)` mean and the `(nk, nk)` covariance over `n` input points and `k` outputs.
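A shape-only sketch of the second item (placeholder tensors, not the library's implementation): with `n` inputs, `k` outputs, and `p` parameters, the Jacobians are `(n, k, p)`; flattening the output dimension gives the `(nk,)` joint mean and the `(nk, nk)` joint covariance `J Σ J^T`.

```python
import torch

n, k, p = 4, 2, 10                   # inputs, outputs, parameters
Jacs = torch.randn(n, k, p)          # per-input Jacobians w.r.t. parameters
Sigma = torch.eye(p)                 # posterior covariance (identity here)
f_map = torch.randn(n, k)            # MAP predictions (placeholder values)

J = Jacs.reshape(n * k, p)
f_mu = f_map.reshape(n * k)          # joint mean, shape (nk,)
f_cov = J @ Sigma @ J.T              # joint covariance, shape (nk, nk)

print(f_mu.shape, f_cov.shape)  # torch.Size([8]) torch.Size([8, 8])
```

The `(nk, nk)` covariance couples all outputs across all inputs, which is what the `joint=True` predictive exposes compared to the per-point marginals.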