**`SubnetLaplace` vs disabling grads**

The way I see it, disabling gradients is another way to implement the subnet Laplace. You can simply switch off the grad of the parameters you don't want (instead of providing a subnet mask), and the backend will automatically compute the Hessian and Jacobian only for the parameters that have `requires_grad = True`. This is why Laplace is applicable to LLMs at all (switch off all grads other than the LoRA params' and do Laplace as usual), since it is basically emulating `SubnetLaplace`.
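For illustration, here is a minimal sketch of this pattern on a toy model (the model and names below are just for illustration, not taken from the docs):

```python
import torch
from laplace import Laplace

# Toy model just for illustration.
model = torch.nn.Sequential(torch.nn.Linear(2, 20), torch.nn.ReLU(), torch.nn.Linear(20, 3))

# Switch off grads for everything except the last layer
# (for an LLM: everything except the LoRA params).
for p in model[:-1].parameters():
    p.requires_grad = False

# Treat it as an ordinary full-network Laplace; the backend only
# "sees" the params with requires_grad = True.
la = Laplace(model, "classification", subset_of_weights="all", hessian_structure="full")
# la.fit(train_loader)  # Hessian/Jacobian are computed only for the trainable params
```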
The main problem of the current implementation of `SubnetLaplace` is that the Jacobian computation for the GLM predictive is done as I said in #213: it computes the full Jacobian (very large!) and then just slices it with the subnet mask.
Of course, for more sophisticated subnet selection, `SubnetLaplace` is still desirable due to the many helper functions it comes with, see https://github.com/aleximmer/Laplace/blob/main/laplace/utils/subnetmask.py. But this is quite orthogonal to the implementation of `SubnetLaplace` itself, i.e. one could implement it by taking a subnet mask and switching off the grads of the params not in the mask.
For last-layer Laplace, it's still preferable to use `LLLaplace` since it's highly optimized, e.g. the Jacobian is computed in a special way, unlike in `SubnetLaplace`. The example bit you referred to is just for intuition purposes :)
The `sample` method of Laplace takes the disabled gradients into account. First, the parameters held by Laplace, and hence `self.n_params`, are just those with `requires_grad = True`:
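(Roughly along these lines; an illustrative sketch, not the verbatim library code:)

```python
# Sketch of the relevant init logic: keep only the trainable parameters.
self.params = [p for p in self.model.parameters() if p.requires_grad]
self.n_params = sum(p.numel() for p in self.params)
```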
Then, in `self.sample`, Laplace generates samples for those `self.n_params` only, e.g.:
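(For example, for a diagonal posterior approximation, roughly; again a sketch, not verbatim:)

```python
# Sketch: draw n_samples standard-normal vectors of length self.n_params
# (i.e. only for the trainable params) and transform them with the posterior.
samples = torch.randn(n_samples, self.n_params, device=self._device)
samples = samples * self.posterior_scale  # elementwise scaling in the diagonal case
return self.mean.reshape(1, self.n_params) + samples
```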
Then, in `self._nn_predictive_samples`, Laplace simply does:
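(Sketched from the description above, not the verbatim implementation:)

```python
from torch.nn.utils import vector_to_parameters

# Sketch: write each posterior sample into the trainable params,
# do a forward pass, then restore the MAP estimate.
fs = []
for sample in self.sample(n_samples):
    vector_to_parameters(sample, self.params)
    fs.append(self.model(X).detach())
vector_to_parameters(self.mean, self.params)
fs = torch.stack(fs)
```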
Note that `self.params` is a reference to the subset of `model.parameters()`, so the call above amounts to updating that subset of params with the sampled values.
I might have missed some of your questions; if so, please just repeat them below, along with any follow-up questions.
Thanks for the detailed answer, that's highly appreciated.
The explanation of last-layer Laplace via disabling grads is very helpful. I wasn't aware of the fact that `self.params` only contains the active parameters. Now it is clear why `sample()` produces vectors of active params only, as in the `LLLaplace` case. This is different from `SubnetLaplace`, which always produces samples of active + fixed params (which is what the test in #216 checks for).
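For concreteness, a small sketch of the shape difference I mean (toy model, regression likelihood and the `n_params_subnet` value are just for illustration):

```python
import torch
from torch.utils.data import DataLoader, TensorDataset
from laplace import Laplace
from laplace.utils import LargestMagnitudeSubnetMask

model = torch.nn.Sequential(torch.nn.Linear(2, 20), torch.nn.ReLU(), torch.nn.Linear(20, 3))
loader = DataLoader(TensorDataset(torch.randn(32, 2), torch.randn(32, 3)), batch_size=8)

# SubnetLaplace: samples have the length of the *full* parameter vector,
# with the fixed entries equal to the MAP values.
subnet_indices = LargestMagnitudeSubnetMask(model, n_params_subnet=50).select()
la_sub = Laplace(model, "regression", subset_of_weights="subnetwork",
                 hessian_structure="full", subnetwork_indices=subnet_indices)
la_sub.fit(loader)
print(la_sub.sample(5).shape)  # (5, total number of params)

# LLLaplace: samples only cover the last-layer params.
la_ll = Laplace(model, "regression", subset_of_weights="last_layer", hessian_structure="full")
la_ll.fit(loader)
print(la_ll.sample(5).shape)   # (5, number of last-layer params)
```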
> **`SubnetLaplace` vs disabling grads**
>
> [...]
>
> Of course, for more sophisticated subnet selection, `SubnetLaplace` is still desirable due to the many helper functions it comes with, see https://github.com/aleximmer/Laplace/blob/main/laplace/utils/subnetmask.py. But this is quite orthogonal to the implementation of `SubnetLaplace` itself, i.e. one could implement it by taking a subnet mask and switching off the grads of the params not in the mask.
That's a good point. However, it looks as if this doesn't work for helpers that operate on individual weights across param tensors, for instance:
```python
import torch as T
from laplace.utils import subnetmask as su

model = T.nn.Sequential(T.nn.Linear(2, 20), T.nn.ReLU(), T.nn.Linear(20, 3))
params = T.nn.utils.parameters_to_vector(model.parameters())
subnetmask = su.LargestMagnitudeSubnetMask(
    model=model, n_params_subnet=int(len(params) * 0.8)
)

# Mask of the parameters that should stay fixed (i.e. not in the subnet).
fixed_mask = T.ones(len(params), dtype=bool)
fixed_mask[subnetmask.select()] = False

# RuntimeError: you can only change requires_grad flags of leaf variables.
params[fixed_mask].requires_grad = False
```
This is because one can, from my understanding, only disable grads on a tensor level and not for single entries.
There is probably an obvious solution, but at the moment I can't think of any. I'd appreciate any hints here. Thanks.
Another question I wanted to ask again: since `LLLaplace` and disabling all but the last layer's grads seem to do the same thing in effect, which of the two methods would you recommend?
Good point on disabling grads at the "tensor" level (or, more accurately, at the `torch.nn.Parameter` level). In this case, `SubnetLaplace` is more flexible.
> Since `LLLaplace` and disabling all but the last layer's grads seem to do the same thing in effect, which of the two methods would you recommend?
I haven't tested this in depth, but my hunch is that last-layer Laplace via disabling grads is more universally applicable than `LLLaplace`. Notice that in the implementation of `LLLaplace` we have to do many extra steps, like creating a `FeatureExtractor`, inferring the last layer, dealing with feature reduction, etc. When it's done by disabling grads, one doesn't have to worry about any of that.
However, again, there are some trade-offs. For instance, `LLLaplace` enables last-layer-specific tricks, such as in the GLM predictive: https://github.com/aleximmer/Laplace/pull/145
Very valuable feedback again, thanks a bunch. So to summarize:

- Disabling grads is preferable over `SubnetLaplace` since it avoids calculating full Jacobians.
- There are three ways of doing last-layer Laplace: disabling grads, `LLLaplace` and `laplace.utils.subnetmask.LastLayerSubnetMask` (the latter is probably only for testing purposes).
- Disabling grads only works at the `torch.nn.Parameter` level, so one cannot make use of the per-weight selection helpers `SubnetLaplace` offers, such as `Largest*SubnetMask` or `RandomSubnetMask`.
- Depending on the method (disabling grads, `SubnetLaplace`, `LLLaplace`), `sample()` returns vectors of different lengths, but this is always corrected for in `_nn_predictive_samples()`.
- For last-layer Laplace, `LLLaplace` offers improved performance (#145).

Is this about right?
I feel that this list could go into the documentation, and I'm happy to add it somewhere. If you think this is useful, let me know and I'll keep this issue open until then. However, if things are being tested and in flux ATM such that documenting them is not worth it, then I'll close this issue and keep it as a temporary reference. Thanks.
I agree that documentation would be good. What would be useful is to distinguish disable-grads & `SubnetLaplace` from `LLLaplace` by application. For example:
The current documentation website with its single-page layout is not very good UX-wise. So if you want to document this discussion, feel free to do so in the README.
I just rewrote my personal site & blog using Astro and had a very good experience. I might look into migrating Laplace's documentation to Starlight soon.
I have some questions regarding the method of defining a subnetwork by disabling grads, termed here "subnet-by-disable-grads". Please excuse the long-ish text below. If there is another place where questions rather than bug reports should be discussed, please let me know (e.g. GitHub discussions).
From #213:
I guess you are referring to this part of the docs and these tests. I assumed that this is just another way of selecting a subset of weights, but that does not seem to be the case.
- When using `SubnetLaplace`, I do see a speedup when doing `la.fit(...)` and in hyperparameter optimization. So I assume the calculation of the full Jacobian you refer to happens during calls to `_glm_predictive_distribution()`, where one won't save compute and memory?
- Is this also true for `subset_of_weights="last_layer"`, i.e. `LLLaplace`? I guess not, since the code looks as if it cuts out the last layer first and then operates only on that. If so, then what would be the benefit of using "subnet-by-disable-grads" over `LLLaplace`?
- There is `laplace.utils.subnetmask.LastLayerSubnetMask`, but I guess that is only for testing things, given that there are two other methods for defining a last-layer subnet (`LLLaplace` and subnet-by-disable-grads).
- The examples for "subnet-by-disable-grads" I have seen so far seem to focus on the last layer, or more generally on disabling grads layer-wise. Since one cannot use any of the `Largest*SubnetMask` or `RandomSubnetMask` helpers to disable individual weights' grads in a tensor via a `parameters_to_vector()` + `vector_to_parameters()` round trip, the method seems to be limited to `last_layer`-type subnet selection. Is this correct?
- The test in #216 checks that, in case of using `SubnetLaplace`, only non-fixed parameters vary across samples. I think this behavior is unique to `SubnetLaplace`-based selections, since the `sample()` method of `LLLaplace` and its subclasses returns only samples of the last-layer weights, which is corrected for in `_nn_predictive_samples()` as far as I can see. I wonder if `sample()` in the case of subnet-by-disable-grads is aware of disabled grads, since all methods which are not in subclasses of `SubnetLaplace` or `LLLaplace` seem to generate "just" vectors of length `n_params`, such that parameter samples would also vary for fixed parameters.

Thanks.