In my experience, asdl works well for differentiable Jacobian computation. However, I would hope that curvlinops eventually offers such functionality as well, since asdl is not actively maintained and, specifically, the differentiability aspect is only available in branches that have not been merged. It is also no longer present in asdfghjkl, where differentiability broke for a reason I couldn't figure out.
Let $J(x) S J(x)^\top\colon \mathbb{R}^k \to \mathbb{R}^k$ be a linear map, i.e. we can evaluate $(J(x) S J(x)^\top)(v)$ for $v \in \mathbb{R}^k$.

(Small $K$, or when we want differentiability.) Then $\mathrm{vmap}(J(x) S J(x)^\top)(I)$ gives us $\mathrm{var}(f(x))$. Note that this is a `vmap` over the number of classes, unlike the current Jacobian implementation.
```python
def model_fn_params_only(params_dict):
    # buffers_dict and x come from the surrounding scope
    out = torch.func.functional_call(self.model, (params_dict, buffers_dict), x)
    return out, out

# torch.func.vjp returns (output, vjp_fn, aux) when has_aux=True
_, vjpfunc, _ = torch.func.vjp(model_fn_params_only, self.params_dict, has_aux=True)

def JSJT(v):
    (v,) = vjpfunc(v)  # J(x)^T v: a dict of parameter-shaped tensors
    v = S @ v  # apply S in parameter space (in practice: flatten, matvec, unflatten)
    # torch.func.jvp takes tangents directly rather than returning a function
    _, v, _ = torch.func.jvp(model_fn_params_only, (self.params_dict,), (v,), has_aux=True)
    return v  # J(x) S J(x)^T v

func_var = torch.func.vmap(JSJT)(I)  # I: the (k, k) identity; rows are standard basis vectors
```
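For concreteness, here is a self-contained version of the above sketch. The toy model, the single input, and the choice $S = I_p$ (identity in flattened parameter space, so the result is just $J(x) J(x)^\top$) are assumptions for illustration only:

```python
import torch
from torch import nn
from torch.func import functional_call, jvp, vjp, vmap

# Toy setup, for illustration only: a small MLP, one input, and S = identity
# (applied as a no-op), so func_var below is simply J(x) J(x)^T.
torch.manual_seed(0)
model = nn.Sequential(nn.Linear(5, 16), nn.Tanh(), nn.Linear(16, 3))
x = torch.randn(1, 5)
k = 3

params = dict(model.named_parameters())

def model_fn_params_only(params_dict):
    out = functional_call(model, params_dict, (x,)).squeeze(0)  # shape (k,)
    return out, out

# vjp returns (output, vjp_fn, aux) when has_aux=True.
_, vjp_fn, _ = vjp(model_fn_params_only, params, has_aux=True)

def JSJT(v):
    (jt_v,) = vjp_fn(v)  # J(x)^T v, a dict of parameter-shaped tensors
    s_jt_v = jt_v  # S = identity here; a real S would act on the flattened vector
    # Push the parameter-space vector back to output space: J(x) S J(x)^T v.
    _, jsjt_v, _ = jvp(model_fn_params_only, (params,), (s_jt_v,), has_aux=True)
    return jsjt_v

I_k = torch.eye(k)
func_var = vmap(JSJT)(I_k)  # (k, k) output covariance var(f(x))
print(func_var)
```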
Alternatively, the `vmap` can be replaced by a Python loop and `torch.stack`; either way we never have to build $J(x) S J(x)^\top$ explicitly from a materialized $J(x)$ and $S$, since we only store a $(k, 1)$ or $(p, 1)$ tensor at a time:

```python
func_var = torch.stack([JSJT(v).detach() for v in I])
```
If we only care about the diagonal of `func_var`:

```python
func_var = torch.stack([JSJT(v).detach()[i] for (i, v) in enumerate(I)])
```
For sampling $f(x)$, this can be done cheaply (see the Laurence paper). For LLMs this might be better because we don't really care about the explicit $J(x) S J(x)^\top$, but only about the resulting $\int \mathrm{softmax}(f(x)) \, \mathcal{N}\big(f(x) \mid f_\theta(x), J(x) S J(x)^\top\big) \, df(x)$. I.e., the cost is now w.r.t. the number of samples instead of $K$.
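As a rough sketch of the sampling route (my own illustration, not code from the issue): if the weight-space covariance admits a factor $L$ with $S = L L^\top$ (an assumption; e.g. a Cholesky or low-rank factor), then each sample of $f(x)$ costs one JVP with tangent $L\varepsilon$, $\varepsilon \sim \mathcal{N}(0, I_p)$, and the integral is estimated by averaging softmax values over samples. The toy model and the `unflatten` helper below are made up for illustration.

```python
import torch
from torch import nn
from torch.func import functional_call, jvp

# Toy setup (illustrative only): small model, input batch, and a dense factor
# L with S = L @ L.T. In practice L would come from the curvature approximation.
torch.manual_seed(0)
model = nn.Sequential(nn.Linear(5, 16), nn.Tanh(), nn.Linear(16, 3))
x = torch.randn(7, 5)

params = dict(model.named_parameters())
names = list(params.keys())
shapes = [p.shape for p in params.values()]
numels = [p.numel() for p in params.values()]
p_total = sum(numels)
L = 0.01 * torch.randn(p_total, p_total)  # assumed factor of S

def f(p_dict):
    return functional_call(model, p_dict, (x,))

def unflatten(vec):
    # Split a flat parameter-space vector into a dict matching the model's parameters.
    return {n: c.view(s) for n, c, s in zip(names, vec.split(numels), shapes)}

n_samples = 32
f_mean = f(params)  # f_theta(x), shape (batch, k)
probs = torch.zeros_like(f_mean)
for _ in range(n_samples):
    eps = torch.randn(p_total)
    tangent = unflatten(L @ eps)  # parameter-space perturbation direction
    _, delta = jvp(f, (params,), (tangent,))  # J(x) L eps, shape (batch, k)
    # f_mean + delta is a sample from N(f_theta(x), J S J^T)
    probs = probs + torch.softmax(f_mean + delta, dim=-1)
probs = probs / n_samples  # MC estimate of the softmax-Gaussian integral
```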
All of them can be optimized further depending on the form of $S$.
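For example (a sketch of my own, not from the issue): if $S$ is diagonal or Kronecker-factored, the `S @ v` step reduces to an elementwise product or to small mat-vecs on the Kronecker factors, without ever materializing a $(p, p)$ matrix. The helper names here are hypothetical.

```python
import torch

def apply_diag_S(s_diag: torch.Tensor, v: torch.Tensor) -> torch.Tensor:
    # Diagonal S: O(p) elementwise product with the flattened vector v.
    return s_diag * v

def apply_kron_S(A: torch.Tensor, B: torch.Tensor, v: torch.Tensor) -> torch.Tensor:
    # Kronecker-factored block S = A ⊗ B (KFAC-style). With row-major reshaping,
    # (A ⊗ B) v = (A @ V @ B.T).reshape(-1), where V = v.reshape(A.shape[1], B.shape[1]).
    V = v.reshape(A.shape[1], B.shape[1])
    return (A @ V @ B.T).reshape(-1)

# Sanity check against the dense Kronecker product on a tiny example.
A, B = torch.randn(3, 3), torch.randn(4, 4)
v = torch.randn(12)
assert torch.allclose(apply_kron_S(A, B, v), torch.kron(A, B) @ v, atol=1e-5)
```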
In short, we need a Jacobian backend that is differentiable and works under `vmap`.