cvlab-epfl / Power-Iteration-SVD

Backpropagation-Friendly-Eigendecomposition

Extend to batch SVD #1

Open tangbohu opened 4 years ago

tangbohu commented 4 years ago

This is very interesting, but I wonder whether it could easily be extended to handle batched SVD? Thanks!

Pikauba commented 4 years ago

I adapted the PowerIteration method to work with batched matrices and eigenvectors, if you are interested.

import torch

class PowerIteration(torch.autograd.Function):
    @staticmethod
    def forward(ctx, M, v, n_iter=19):
        # M: (B, n, n) batch of matrices; v: (B, n, 1) batch of (approximate)
        # dominant eigenvectors, assumed precomputed without gradient tracking.
        ctx.n_iter = n_iter
        ctx.save_for_backward(M, v)

        return v

    @staticmethod
    def backward(ctx, grad_output):
        M, v = ctx.saved_tensors
        dL_dv = grad_output  # (B, n, 1)
        # Batched identity matching M's dtype and device: (B, n, n).
        I = torch.eye(M.shape[-1], dtype=M.dtype, device=M.device).expand_as(M)
        # Truncated series for the gradient of the power-iteration fixed
        # point: term1 = sum_{k=0}^{n_iter} q^k a_0, with
        # a_0 = (I - v v^T) / ||M v|| and q = M / ||M v||.
        num = I - torch.bmm(v, torch.transpose(v, 2, 1))
        denom = torch.norm(torch.bmm(M, v), dim=(1, 2), keepdim=True).clamp(min=1e-5)
        ak = torch.div(num, denom)
        term1 = ak.clone()
        q = torch.div(M, denom)
        for _ in range(ctx.n_iter):
            ak = torch.bmm(q, ak)
            term1 += ak

        dL_dM = torch.bmm(torch.bmm(term1, dL_dv), torch.transpose(v, 2, 1))

        # The second value fills the gradient slot for v. As in the original
        # unbatched code, v is treated as a constant buffer that does not
        # require grad, so autograd discards it; note its shape (B, n, n)
        # would not match v's shape if v ever did require grad.
        return dL_dM, ak
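For anyone wanting to try it end to end, here is one way to drive the class: run a plain forward power iteration with gradients disabled, then pass the converged vector through `PowerIteration.apply` so the backward pass uses the truncated series. This is only a sketch under my own assumptions (symmetric PSD batch, e.g. covariance matrices, and the helper name `batch_dominant_eigenvector` is mine, not from the repo); the class is repeated so the snippet runs on its own.

```python
import torch

class PowerIteration(torch.autograd.Function):
    @staticmethod
    def forward(ctx, M, v, n_iter=19):
        ctx.n_iter = n_iter
        ctx.save_for_backward(M, v)
        return v

    @staticmethod
    def backward(ctx, grad_output):
        M, v = ctx.saved_tensors
        I = torch.eye(M.shape[-1], dtype=M.dtype, device=M.device).expand_as(M)
        num = I - torch.bmm(v, v.transpose(2, 1))
        denom = torch.norm(torch.bmm(M, v), dim=(1, 2), keepdim=True).clamp(min=1e-5)
        ak = num / denom
        term1 = ak.clone()
        q = M / denom
        for _ in range(ctx.n_iter):
            ak = torch.bmm(q, ak)
            term1 += ak
        dL_dM = torch.bmm(torch.bmm(term1, grad_output), v.transpose(2, 1))
        return dL_dM, ak  # second value is discarded: v does not require grad

def batch_dominant_eigenvector(M, n_iter=50):
    # Hypothetical helper: forward power iteration without gradient tracking,
    # then one differentiable pass through PowerIteration for the backward.
    B, n, _ = M.shape
    v = torch.full((B, n, 1), n ** -0.5, dtype=M.dtype, device=M.device)
    with torch.no_grad():
        for _ in range(n_iter):
            v = torch.bmm(M, v)
            v = v / v.norm(dim=1, keepdim=True).clamp(min=1e-5)
    return PowerIteration.apply(M, v)

torch.manual_seed(0)
A = torch.randn(4, 5, 5, dtype=torch.float64)
M = torch.bmm(A, A.transpose(1, 2))  # batch of symmetric PSD matrices
M.requires_grad_(True)
v = batch_dominant_eigenvector(M)    # (4, 5, 1), unit norm per batch element
v.sum().backward()                   # toy loss, just to exercise the backward
print(v.shape, M.grad.shape)
```

Since the forward pass runs under `torch.no_grad()`, only the custom backward contributes to `M.grad`, which is the whole point of the approach.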