dmarx / notebooks

misc notebooks i wanted to put in tracking

[klmc2] Kronecker HVP #28

Closed: dmarx closed this issue 1 year ago

dmarx commented 1 year ago

https://discordapp.com/channels/729741769192767510/730484623028519072/1108512277671456808

    # State for the Kronecker HVP approximation
    x_mean = torch.zeros_like(x, device=x.device)  # EMA of the gradient
    # EMAs of the per-axis gradient second moments (one factor per non-batch axis)
    x_kron_1 = torch.zeros([x.shape[0], x.shape[1], x.shape[1]], device=x.device)
    x_kron_2 = torch.zeros([x.shape[0], x.shape[2], x.shape[2]], device=x.device)
    x_kron_3 = torch.zeros([x.shape[0], x.shape[3], x.shape[3]], device=x.device)
    beta_accum = torch.tensor(1.0, device=x.device)  # tracks beta**t for bias correction
    beta = torch.tensor(0.99, device=x.device)       # EMA decay rate
    damping = torch.tensor(0.0, device=x.device)     # added to eigenvalue magnitudes in powmh
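
Quick gloss on the last two buffers (my annotation, not spelled out in the thread): `beta` is the EMA decay and `beta_accum` tracks $\beta^t$, which the function below uses for Adam-style bias correction of the zero-initialized EMAs:

$$m_t = \beta\, m_{t-1} + (1-\beta)\, g_t,\quad m_0 = 0 \;\;\Rightarrow\;\; \hat m_t = \frac{m_t}{1-\beta^t}.$$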

...

    def powmh(x, power, damping=0.0):
        """Symmetric matrix power via eigendecomposition: V diag((|lambda| + damping)**power) V^T."""
        vals, vecs = torch.linalg.eigh(x)
        vals = vals.abs().add(damping).pow(power)
        return torch.einsum("...ab,...b,...cb->...ac", vecs, vals, vecs)

    def batch_hvp_fn(x, sigma, v):
        """Kronecker method"""
        grad = grad_fn(x, sigma)
        x_mean.mul_(beta).add_(grad, alpha=1 - beta)
        x_kron_1.mul_(beta).add_(torch.einsum("nabc,nzbc->naz", grad, grad), alpha=(1 - beta) / grad.shape[2] / grad.shape[3])
        x_kron_2.mul_(beta).add_(torch.einsum("nabc,nazc->nbz", grad, grad), alpha=(1 - beta) / grad.shape[1] / grad.shape[3])
        x_kron_3.mul_(beta).add_(torch.einsum("nabc,nabz->ncz", grad, grad), alpha=(1 - beta) / grad.shape[1] / grad.shape[2])
        beta_accum.mul_(beta)
        x_mean_hat = x_mean / (1 - beta_accum)
        x_kron_1_hat = x_kron_1 / (1 - beta_accum)
        x_kron_2_hat = x_kron_2 / (1 - beta_accum)
        x_kron_3_hat = x_kron_3 / (1 - beta_accum)
        # Note: this mean subtraction is what the follow-up comment corrects; the per-axis
        # mean is a vector, so squaring it does not give the needed outer-product term
        # (and the shapes don't broadcast against the [n, d, d] Gram matrices in general).
        x_cov_1 = x_kron_1_hat - x_mean_hat.mean([2, 3]) ** 2
        x_cov_2 = x_kron_2_hat - x_mean_hat.mean([1, 3]) ** 2
        x_cov_3 = x_kron_3_hat - x_mean_hat.mean([1, 2]) ** 2
        x_cov_1_root = powmh(x_cov_1, 1 / 3, damping)
        x_cov_2_root = powmh(x_cov_2, 1 / 3, damping)
        x_cov_3_root = powmh(x_cov_3, 1 / 3, damping)
        hvp = torch.einsum("xyabc,yad,ybe,ycf->xydef", v, x_cov_1_root, x_cov_2_root, x_cov_3_root)
        return grad, hvp
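
For reference (my reading, not spelled out in the thread): since the three cov roots are symmetric, the final einsum is equivalent to multiplying $\mathrm{vec}(v)$, per batch element, by their Kronecker product, i.e. the matrix $1/3$-power of the factored covariance:

$$\widehat{Hv} \;\approx\; \big(C_1^{1/3} \otimes C_2^{1/3} \otimes C_3^{1/3}\big)\,\mathrm{vec}(v) \;=\; \big(C_1 \otimes C_2 \otimes C_3\big)^{1/3}\,\mathrm{vec}(v),$$

using $(A \otimes B)^p = A^p \otimes B^p$ for symmetric PSD $A$, $B$.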
dmarx commented 1 year ago

correction: https://discord.com/channels/729741769192767510/730484623028519072/1108585811143827496

    def batch_hvp_fn(x, sigma, v):
        """Kronecker method"""
        grad = grad_fn(x, sigma)
        # Update the EMA of the gradient and of the per-axis Gram matrices.
        x_mean.mul_(beta).add_(grad, alpha=1 - beta)
        x_kron_1.mul_(beta).add_(torch.einsum("nabc,nzbc->naz", grad, grad), alpha=(1 - beta) / (grad.shape[2] * grad.shape[3]))
        x_kron_2.mul_(beta).add_(torch.einsum("nabc,nazc->nbz", grad, grad), alpha=(1 - beta) / (grad.shape[1] * grad.shape[3]))
        x_kron_3.mul_(beta).add_(torch.einsum("nabc,nabz->ncz", grad, grad), alpha=(1 - beta) / (grad.shape[1] * grad.shape[2]))
        # Adam-style bias correction for the zero-initialized EMAs.
        beta_accum.mul_(beta)
        x_mean_hat = x_mean / (1 - beta_accum)
        x_kron_1_hat = x_kron_1 / (1 - beta_accum)
        x_kron_2_hat = x_kron_2 / (1 - beta_accum)
        x_kron_3_hat = x_kron_3 / (1 - beta_accum)
        # Per-axis covariances: subtract the outer product of the per-axis means
        # (this is the fix relative to the first version above).
        x_mean_1 = x_mean_hat.mean([2, 3])
        x_mean_2 = x_mean_hat.mean([1, 3])
        x_mean_3 = x_mean_hat.mean([1, 2])
        x_cov_1 = x_kron_1_hat - torch.einsum("na,nb->nab", x_mean_1, x_mean_1)
        x_cov_2 = x_kron_2_hat - torch.einsum("na,nb->nab", x_mean_2, x_mean_2)
        x_cov_3 = x_kron_3_hat - torch.einsum("na,nb->nab", x_mean_3, x_mean_3)
        # Cube roots of the factors; for PSD factors their Kronecker product equals
        # (x_cov_1 ⊗ x_cov_2 ⊗ x_cov_3)**(1/3).
        x_cov_1_root = powmh(x_cov_1, 1 / 3, damping)
        x_cov_2_root = powmh(x_cov_2, 1 / 3, damping)
        x_cov_3_root = powmh(x_cov_3, 1 / 3, damping)
        # Apply the Kronecker-factored matrix to v along each non-batch axis.
        hvp = torch.einsum("xyabc,yad,ybe,ycf->xydef", v, x_cov_1_root, x_cov_2_root, x_cov_3_root)
        return grad, hvp
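
A minimal standalone sanity check of the contraction (my sketch; the shapes and names below are made up, and it skips the EMA / mean-subtraction parts): it verifies that the `"xyabc,yad,ybe,ycf->xydef"` einsum matches multiplying `vec(v)` by the explicit Kronecker product of the three symmetric factors, one batch element at a time.

    import torch

    def powmh(x, power, damping=0.0):
        """Symmetric matrix power via eigendecomposition (same helper as above)."""
        vals, vecs = torch.linalg.eigh(x)
        vals = vals.abs().add(damping).pow(power)
        return torch.einsum("...ab,...b,...cb->...ac", vecs, vals, vecs)

    torch.manual_seed(0)
    n, a, b, c = 2, 3, 4, 5                 # batch size and per-axis dims (arbitrary)
    x = torch.randn(n, a, b, c)

    # Per-axis second-moment factors, using the same einsums/normalizations as
    # x_kron_1/2/3 above (no EMA or mean subtraction in this toy check).
    c1 = torch.einsum("nabc,nzbc->naz", x, x) / (b * c)
    c2 = torch.einsum("nabc,nazc->nbz", x, x) / (a * c)
    c3 = torch.einsum("nabc,nabz->ncz", x, x) / (a * b)
    r1, r2, r3 = (powmh(m, 1 / 3) for m in (c1, c2, c3))

    v = torch.randn(7, n, a, b, c)          # arbitrary leading "x" dimension
    hvp = torch.einsum("xyabc,yad,ybe,ycf->xydef", v, r1, r2, r3)

    # Reference: per batch element y, vec(hvp) should equal (r1 ⊗ r2 ⊗ r3) @ vec(v).
    # The two agree because each r_i is symmetric.
    for y in range(n):
        kron = torch.kron(torch.kron(r1[y], r2[y]), r3[y])
        ref = (kron @ v[:, y].reshape(7, -1).T).T.reshape(7, a, b, c)
        assert torch.allclose(hvp[:, y], ref, atol=1e-4)
    print("einsum HVP matches the explicit Kronecker product")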