[Open] tangbohu opened this issue 4 years ago
I adapted the PowerIteration method to work with batched matrices and eigenvectors, in case you are interested.
class PowerIteration(torch.autograd.Function):
    """Identity on ``v`` in forward; power-iteration eigenvector gradient in backward.

    ``forward`` returns ``v`` unchanged. ``backward`` sends ``grad_output``
    (dL/dv) to ``M`` using a truncated power series::

        d     = max(||M v||_F, 1e-5)
        dL/dM = (sum_{k=0}^{n_iter} (M / d)^k (I - v v^T) / d) @ dL/dv @ v^T

    No gradient is propagated to ``v`` or ``n_iter``.

    Shapes: ``M`` is (..., n, n) and ``v`` is (..., n, k). Any number of
    leading batch dimensions is accepted, including none.

    NOTE(review): this is the single-eigenvector formula (k == 1). With
    k > 1, the projector ``I - v v^T`` and the Frobenius-norm denominator
    are shared across all columns of ``v``, so the result is not a
    per-column eigenvector gradient. Confirm this is intended before using
    multi-column ``v``.
    """

    @staticmethod
    def forward(ctx, M, v, n_iter=19):
        """Return ``v`` unchanged and stash ``M``, ``v`` and ``n_iter`` for backward.

        Args:
            M: Square matrix (or batch of them), shape (..., n, n).
            v: Eigenvector(s) of ``M``, shape (..., n, k).
            n_iter: Number of power-series terms after the k = 0 term.
        """
        ctx.n_iter = n_iter
        ctx.save_for_backward(M, v)
        return v

    @staticmethod
    def backward(ctx, grad_output):
        """Map dL/dv to dL/dM via the truncated power series (see class docstring).

        Returns:
            ``(dL_dM, None, None)``: gradients for ``M``, ``v`` and ``n_iter``.
            The trailing ``None`` is dropped by autograd when ``n_iter`` was
            not passed to ``apply``, so both call forms work.
        """
        M, v = ctx.saved_tensors
        dL_dv = grad_output
        v_t = v.transpose(-2, -1)

        # Build a single (n, n) identity and let it broadcast over the batch dims.
        # This avoids torch.eye(out=...) resizing a (B, n, n) buffer.
        eye = torch.eye(M.shape[-1], dtype=M.dtype, device=M.device)
        num = eye - v @ v_t

        # Frobenius norm over the last two dims; keepdim so it broadcasts as (..., 1, 1).
        # The clamp guards against division by ~0.
        denom = torch.norm(M @ v, dim=(-2, -1), keepdim=True).clamp(min=1e-5)

        ak = num / denom
        term1 = ak.clone()  # clone so the in-place accumulation below does not alias ak
        q = M / denom
        for _ in range(ctx.n_iter):
            ak = q @ ak  # ak_k = (M / d)^k (I - v v^T) / d
            term1 += ak

        dL_dM = term1 @ dL_dv @ v_t
        # v receives no gradient. The old code returned `ak`, which has shape
        # (..., n, n) and cannot be a gradient for v of shape (..., n, k).
        return dL_dM, None, None
This is very interesting, but I wonder whether it could easily be extended to handle batched SVD. Thanks!