aleximmer / Laplace

Laplace approximations for Deep Learning.
https://aleximmer.github.io/Laplace
MIT License

Questions about defining a subnet by disabling gradients #217

Closed elcorto closed 3 months ago

elcorto commented 3 months ago

I have some questions regarding the method of defining a subnetwork by disabling grads, termed here "subnet-by-disable-grads". Please excuse the longish text below. If there is another place where questions (rather than bug reports) should be discussed, e.g. GitHub Discussions, please let me know.

From #213:

I have to admit that we don't pay too much attention to the SubnetLaplace implementation anymore since the same can be done more intuitively by switching off the unwanted subset's grads. The benefit of the latter is that the Jacobian for the linearized predictive will only be computed over the selected subset. Meanwhile, the former computes the full Jacobian first and then slices it using the subnet indices.

I guess you are referring to this part of the docs and these tests. I assumed that this is just another way of selecting a subset of weights, but that does not seem to be the case.

When using SubnetLaplace, I do see a speedup when calling la.fit(...) and in hyperparameter optimization. So I assume the calculation of the full Jacobian you refer to happens during calls to _glm_predictive_distribution(), which is where one doesn't save compute and memory?

Is this also true for subset_of_weights="last_layer", i.e. LLLaplace? I guess not, since the code looks as if it extracts the last layer first and then operates only on that. If so, what would be the benefit of using "subnet-by-disable-grads" over LLLaplace?

There is laplace.utils.subnetmask.LastLayerSubnetMask, but I guess that is only for testing things given that there are two other methods for defining a last-layer subnet (LLLaplace and subnet-by-disable-grads).

The examples for "subnet-by-disable-grads" I have seen so far seem to focus on the last layer or, more generally, on disabling grads layer-wise. Since one cannot use any of the Largest*SubnetMask or RandomSubnetMask helpers to disable individual weights' grads within a tensor via a parameters_to_vector() + vector_to_parameters() round trip, the method seems to be limited to layer-wise (e.g. last_layer-type) subnet selection. Is this correct?

The test in #216 checks that, when using SubnetLaplace, only non-fixed parameters vary across samples. I think this behavior is unique to SubnetLaplace-based selections, since the sample() method of LLLaplace and its subclasses returns only samples of the last-layer weights, which is corrected for in _nn_predictive_samples() as far as I can see. I wonder whether sample() is aware of disabled grads in the subnet-by-disable-grads case, since all sample() methods outside the SubnetLaplace and LLLaplace subclasses seem to generate "just" vectors of length n_params, such that parameter samples would also vary for fixed parameters.

Thanks.

wiseodd commented 3 months ago

SubnetLaplace vs disabling grads

The way I see it, disabling gradients is another way to implement subnet Laplace. You can just switch off the grads of the parameters you don't want (instead of providing a subnet mask), and the backend will automatically compute the Hessian and Jacobian only for the parameters that have requires_grad = True. This is why Laplace is applicable to LLMs at all (switch off all grads other than the LoRA params' and do Laplace as usual), since it's basically emulating SubnetLaplace.
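For instance, a minimal sketch of this route (with an arbitrary toy model; train_loader and the Laplace arguments are placeholders to adapt to your setup):

import torch
from laplace import Laplace

model = torch.nn.Sequential(
    torch.nn.Linear(2, 20), torch.nn.ReLU(), torch.nn.Linear(20, 3)
)

# Switch off grads for everything except the last layer.
for p in model.parameters():
    p.requires_grad = False
for p in model[-1].parameters():
    p.requires_grad = True

# The backend then computes the Hessian/Jacobians only over the
# requires_grad=True parameters.
la = Laplace(model, "regression", subset_of_weights="all", hessian_structure="full")
# la.fit(train_loader)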

The main problem of the current implementation of SubnetLaplace is that the Jacobian computation for the GLM predictive is done as I said in #213:

https://github.com/aleximmer/Laplace/blob/553cf7c4e7b5bded8760c3244ed0ff2dbb11b191/laplace/curvature/asdl.py#L92-L93

It computes the full Jacobian (very large!) and then just slices it with the subnet mask.

Of course, for more sophisticated subnet selection, SubnetLaplace is still desirable due to the many helper functions, see https://github.com/aleximmer/Laplace/blob/main/laplace/utils/subnetmask.py. But this is quite orthogonal to the implementation of SubnetLaplace itself, i.e. one could just as well take a subnet mask and switch off the grads of the params not in the mask.
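For reference, the SubnetLaplace route with one of those helpers looks roughly like this (a sketch following the documented usage; the toy model and n_params_subnet are arbitrary):

import torch
from laplace import Laplace
from laplace.utils.subnetmask import LargestMagnitudeSubnetMask

model = torch.nn.Sequential(
    torch.nn.Linear(2, 20), torch.nn.ReLU(), torch.nn.Linear(20, 3)
)

# Select the 40 largest-magnitude weights (across all tensors) as the subnet.
subnetwork_mask = LargestMagnitudeSubnetMask(model, n_params_subnet=40)
subnetwork_indices = subnetwork_mask.select()

la = Laplace(
    model,
    "regression",
    subset_of_weights="subnetwork",
    hessian_structure="full",
    subnetwork_indices=subnetwork_indices,
)
# la.fit(train_loader)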

Last-layer Laplace

For last-layer Laplace, it's still preferable to use LLLaplace since it's highly optimized, e.g. the Jacobian is computed in a special way, unlike in SubnetLaplace. The example bit you referred to is just for intuition purposes :)
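Concretely, that just means using the factory with subset_of_weights="last_layer" (a sketch; likelihood and hessian_structure depend on your setup):

from laplace import Laplace

# LLLaplace extracts the last layer (via a FeatureExtractor) and treats
# only its weights as stochastic.
la = Laplace(
    model, "regression", subset_of_weights="last_layer", hessian_structure="kron"
)
# la.fit(train_loader)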

Sampling

The sample method of Laplace takes into account the disabled gradients. First, the parameters held by Laplace and hence self.n_params are just those with requires_grad = True:

https://github.com/aleximmer/Laplace/blob/553cf7c4e7b5bded8760c3244ed0ff2dbb11b191/laplace/baselaplace.py#L114-L123

Then, in self.sample, Laplace generates samples for those self.n_params only, e.g.:

https://github.com/aleximmer/Laplace/blob/553cf7c4e7b5bded8760c3244ed0ff2dbb11b191/laplace/baselaplace.py#L1495-L1503

Then in self._nn_predictive_samples Laplace simply does

https://github.com/aleximmer/Laplace/blob/553cf7c4e7b5bded8760c3244ed0ff2dbb11b191/laplace/baselaplace.py#L1168

Note that self.params is a reference to the subset of model.parameters() with requires_grad = True. Calling the above amounts to updating only that subset of params with the sampled values.
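In other words, the mechanism is roughly the following (an illustrative sketch, not the library's actual code; the dummy posterior is only there to make the shapes concrete):

import torch
from torch.nn.utils import parameters_to_vector, vector_to_parameters

model = torch.nn.Sequential(
    torch.nn.Linear(2, 20), torch.nn.ReLU(), torch.nn.Linear(20, 3)
)
for p in model.parameters():
    p.requires_grad = False
for p in model[-1].parameters():
    p.requires_grad = True

# Only the parameters with requires_grad=True are held by Laplace ...
params = [p for p in model.parameters() if p.requires_grad]
n_params = sum(p.numel() for p in params)

# ... so a posterior sample is a vector of length n_params (here a dummy
# Gaussian around the current weights, standing in for the fitted posterior) ...
posterior_mean = parameters_to_vector(params).detach()
sample = posterior_mean + 0.1 * torch.randn(n_params)

# ... and writing it back only updates the active subset of the model;
# the fixed (requires_grad=False) params are untouched.
vector_to_parameters(sample, params)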

wiseodd commented 3 months ago

I might have missed some of your questions, so please just repeat them below in that case, or ask if you have any follow-up questions.

elcorto commented 3 months ago

Thanks for the detailed answer, that's highly appreciated.

The explanation of last-layer-by-disabling-grads is very helpful. I wasn't aware that self.params only contains the active params. Now it is clear why sample() produces vectors of active params only, as in the LLLaplace case. This is different from SubnetLaplace, which always produces samples of active + fixed params (which is what the test in #216 checks for).

SubnetLaplace vs disabling grads

[...]

Of course, for more sophisticated subnet selection, SubnetLaplace is still desirable due to the existence of many helper functions, see https://github.com/aleximmer/Laplace/blob/main/laplace/utils/subnetmask.py But this is quite orthogonal to the implementation to SubnetLaplace itself, i.e. one can implement it by taking a subnet mask and switching off the grad of the params not in the mask.

That's a good point. However, it looks as if this doesn't work for helpers that operate on individual weights across param tensors, for instance:

import torch as T
from laplace.utils import subnetmask as su

model = T.nn.Sequential(T.nn.Linear(2, 20), T.nn.ReLU(), T.nn.Linear(20, 3))

params = T.nn.utils.parameters_to_vector(model.parameters())
subnetmask = su.LargestMagnitudeSubnetMask(
    model=model, n_params_subnet=int(len(params) * 0.8)
)
fixed_mask = T.ones(len(params), dtype=bool)
fixed_mask[subnetmask.select()] = False

# RuntimeError: you can only change requires_grad flags of leaf variables.
params[fixed_mask].requires_grad = False

This is because one can, from my understanding, only disable grads on a tensor level and not for single entries.

There is probably an obvious solution, but at the moment I can't think of any. I'd appreciate any hints here. Thanks.

elcorto commented 3 months ago

Another question I wanted to ask again is: Since LLLaplace and disabling all but the last layer's grads seem to do the same in effect, which of the methods would you recommend?

wiseodd commented 3 months ago

Good point on only being able to disable grads at the "tensor" level (or, more accurately, at the torch.nn.Parameter level). In this case, SubnetLaplace is more flexible.

Since LLLaplace and disabling all but the last layer's grads seem to do the same in effect, which of the methods would you recommend?

I haven't tested this in depth, but my hunch is that last-layer Laplace via disabling grads is more universally applicable than LLLaplace. Notice that in the implementation of LLLaplace we have to do many extra steps, like creating a FeatureExtractor, inferring the last layer, dealing with feature reduction, etc. When done by disabling grads, one doesn't have to worry about any of that.

However, again, there are some trade-offs. For instance, LLLaplace enables last-layer-specific tricks, such as in the GLM predictive: https://github.com/aleximmer/Laplace/pull/145

elcorto commented 3 months ago

Very valuable feedback again, thanks a bunch. So to summarize:

- Subnet-by-disable-grads: the backend computes the Hessian and Jacobians only over the params with requires_grad = True, and sample() also only covers those params. However, grads can only be disabled at the torch.nn.Parameter (tensor) level, so per-weight subnet selection is not possible this way.
- SubnetLaplace: supports arbitrary per-weight subnets via the subnetmask helpers, but the GLM predictive computes the full Jacobian first and then slices it, and sample() returns active + fixed params.
- LLLaplace: the most optimized option for the last-layer case, with last-layer-specific tricks (e.g. in the GLM predictive), at the cost of extra machinery (FeatureExtractor, last-layer inference, etc.).

Is this about right?

I feel that this list could go into the documentation, and I'm happy to add it somewhere. If you think this is useful, let me know and I'll keep this issue open until then. However, if things are being tested and in flux at the moment, such that documenting them is not worth it, then I'll close this issue and keep it as a temporary reference. Thanks.

wiseodd commented 3 months ago

I agree that documentation would be good. What would be useful is to distinguish disable-grads, SubnetLaplace, and LLLaplace by application. For example:

- large / fine-tuned models (e.g. LLMs with LoRA): switch off grads and use Laplace as usual;
- fine-grained, per-weight subnet selection: SubnetLaplace with the subnetmask helpers;
- standard last-layer Laplace: LLLaplace.

The current documentation website with its single-page layout is not very good UX-wise. So if you want to document this discussion, feel free to do so in the README.

I just rewrote my personal site & blog using Astro and had a very good experience. I might look into migrating Laplace's documentation to Starlight soon.