aleximmer / Laplace

Laplace approximations for Deep Learning.
https://aleximmer.github.io/Laplace
MIT License

Efficient, universal, standalone Jacobian backend #203

Open wiseodd opened 4 months ago

wiseodd commented 4 months ago

We need a Jacobian backend that is efficient, universal, and standalone (cf. the issue title).

aleximmer commented 4 months ago

In my experience, asdl works well for differentiable Jacobian computation. However, I would hope that curvlinops eventually offers such functionality as well, since asdl is not actively maintained and, in particular, the differentiability aspect is only available in branches that have not been merged. It is also no longer present in asdfghjkl, where differentiability broke for a reason I couldn't figure out.

wiseodd commented 4 months ago

Result of a discussion with @f-dangel:

Let J(x) S J(x)^T : \R^K \to \R^K be the linear map v \mapsto (J(x) S J(x)^T) v for v \in \R^K, where K is the number of model outputs (classes).

Computation with vmap (for small K, or when differentiability is needed)

Then vmap(J(x) S J(x)^T)(I), with I the K x K identity, gives us var(f(x)). Note that this vmaps over the number of classes, unlike the current Jacobian implementation.

```python
def model_fn_params_only(params_dict):
    # forward pass with explicit parameters; x and self.buffers_dict are closed over
    out = torch.func.functional_call(self.model, (params_dict, self.buffers_dict), x)
    return out

# vjpfunc(v) computes J(x)^T v as a dict of parameter-shaped tensors
_, vjpfunc = torch.func.vjp(model_fn_params_only, self.params_dict)

def JSJT(v):
    (v,) = vjpfunc(v)       # J(x)^T v
    v = S_matvec(v)         # S J(x)^T v; S_matvec is a placeholder for applying S
                            # in parameter space (e.g. flatten, multiply by S, unflatten)
    _, v = torch.func.jvp(  # J(x) S J(x)^T v
        model_fn_params_only, (self.params_dict,), (v,)
    )
    return v

I = torch.eye(K)  # K = number of outputs
func_var = torch.func.vmap(JSJT)(I)
```

Computation with a for-loop (for large K, when differentiability is not needed)

```python
func_var = torch.stack([JSJT(v).detach() for v in I])
```

If we only care about the diagonal of func_var:

```python
func_var = torch.stack([JSJT(v).detach()[i] for (i, v) in enumerate(I)])
```

Sampling

For sampling f(x), this can be done cheaply (see the Laurence paper). For LLMs this might be better, because we don't really care about the explicit J(x) S J(x)^T, but only about the resulting \int softmax(f(x)) N(f(x) | f_\theta(x), J(x) S J(x)^T) df(x). I.e., the cost now scales with the number of samples instead of with K.
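
A minimal sketch of this route (not the library's implementation): draw parameter-space noise, push it through the Jacobian with one jvp per sample, and average the softmax probabilities. `S_sqrt_matvec` (applying a factor L with S = L L^T to a parameter dict) and `num_samples` are assumed names, not part of the codebase.

```python
def sampled_predictive(num_samples):
    f_mean = model_fn_params_only(self.params_dict)  # f_theta(x)
    probs = torch.zeros_like(torch.softmax(f_mean, dim=-1))
    for _ in range(num_samples):
        # parameter-shaped standard-normal noise eps ~ N(0, I)
        eps = {k: torch.randn_like(p) for k, p in self.params_dict.items()}
        # one jvp per sample: J(x) S^{1/2} eps, so cost scales with num_samples, not K
        _, f_noise = torch.func.jvp(
            model_fn_params_only, (self.params_dict,), (S_sqrt_matvec(eps),)
        )
        probs = probs + torch.softmax(f_mean + f_noise, dim=-1)
    return probs / num_samples
```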

Further things

All of these can be optimized further depending on the form of S (e.g. diagonal, Kronecker-factored, or low-rank).
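
For instance, with a diagonal posterior covariance the S multiply collapses to an elementwise product. A rough sketch, assuming `S_diag` is a dict of parameter-shaped variance tensors (hypothetical name, not part of the codebase):

```python
def S_matvec_diag(v_dict):
    # diagonal S: elementwise multiply instead of a dense matrix-vector product
    return {k: S_diag[k] * v for k, v in v_dict.items()}
```

Kronecker-factored and low-rank forms admit similar shortcuts via their factor matrices.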