Closed: dmarx closed this issue 1 year ago.
Correction (see Discord discussion): https://discord.com/channels/729741769192767510/730484623028519072/1108585811143827496
def batch_hvp_fn(x, sigma, v):
    """Kronecker method: gradient plus a Kronecker-factored product with ``v``.

    Calls ``grad_fn(x, sigma)``; the einsum subscripts below require that
    gradient to be 4-D (``nabc``). The gradient is folded into running
    exponential moving averages (decay ``beta``) that live outside this
    function. Then ``v`` (5-D, subscripts ``xyabc``) is contracted with one
    ``powmh(cov, 1 / 3, damping)`` factor for each trailing gradient axis.

    Side effects:
        ``x_mean``, ``x_kron_1``, ``x_kron_2``, ``x_kron_3`` and
        ``beta_accum`` are updated in place, in that order.

    NOTE(review): ``powmh`` is defined elsewhere; presumably a damped power
    of a symmetric matrix -- confirm against its definition.

    Returns:
        ``(grad, hvp)``: the raw gradient and the preconditioned ``v``.
    """
    g = grad_fn(x, sigma)
    keep = 1 - beta  # EMA weight given to the newest observation

    # Running first moment of the gradient.
    x_mean.mul_(beta).add_(g, alpha=keep)

    # Running second moments, one per trailing axis: contract the gradient
    # with itself over the other two trailing axes, normalised by their size.
    x_kron_1.mul_(beta).add_(
        torch.einsum("nabc,nzbc->naz", g, g),
        alpha=keep / (g.shape[2] * g.shape[3]),
    )
    x_kron_2.mul_(beta).add_(
        torch.einsum("nabc,nazc->nbz", g, g),
        alpha=keep / (g.shape[1] * g.shape[3]),
    )
    x_kron_3.mul_(beta).add_(
        torch.einsum("nabc,nabz->ncz", g, g),
        alpha=keep / (g.shape[1] * g.shape[2]),
    )

    # Bias correction: beta_accum tracks beta ** step.
    beta_accum.mul_(beta)
    debias = 1 - beta_accum
    mean_hat = x_mean / debias

    # For each axis, centre the second moment by subtracting the outer
    # product of the mean taken over the other two trailing axes.
    factors = []
    for kron, other_dims in (
        (x_kron_1, [2, 3]),
        (x_kron_2, [1, 3]),
        (x_kron_3, [1, 2]),
    ):
        axis_mean = mean_hat.mean(other_dims)
        cov = kron / debias - torch.einsum("na,nb->nab", axis_mean, axis_mean)
        factors.append(powmh(cov, 1 / 3, damping))

    hvp = torch.einsum("xyabc,yad,ybe,ycf->xydef", v, *factors)
    return g, hvp
https://discordapp.com/channels/729741769192767510/730484623028519072/1108512277671456808