jiangtaoxie / fast-MPN-COV

@CVPR2018: Efficient unrolled iterative matrix square-root normalized ConvNets, implemented in PyTorch (including code for B-CNN, Compact Bilinear Pooling, etc.), for training from scratch & fine-tuning.
http://peihuali.org/iSQRT-COV/index.html
MIT License
270 stars 56 forks source link

fix Sqrtm forward / backward implementation #7

Closed WarBean closed 5 years ago

WarBean commented 5 years ago

经过对MPNCOV.py里面的各个Function的测试,发现Sqrtm的前向后向似乎有两个问题:

1.当Sqrtm输入的batchSize x dim x dim矩阵不是对称阵的时候,后向梯度不对; 2.当Sqrtm的iterN为1的时候,前向结果和后向梯度不对。

只不过论文中刚好都是输入对称阵、iterN > 1的情况,所以没触发上述错误。

测试方法:跟PyTorch autograd出来的结果进行对比

测试脚本如下,其中注释为“修改(1)”、“修改(2)”的地方分别修复上述两个问题:

import torch
import numpy as np
from torch.autograd import Function

class Sqrtm(Function):
    """Batched matrix square root via coupled Newton-Schulz iteration.

    ``forward`` computes y ~ sqrt(x) for a batch of (dim x dim) matrices by
    running ``iterN`` Newton-Schulz steps on the trace-normalized input;
    ``backward`` implements the matching hand-derived gradient.

    The comments tagged "Fix (1)" / "Fix (2)" mark the two corrections
    proposed in this issue:
      Fix (1): transpose the intermediate gradient so backward also matches
               autograd for non-symmetric inputs.
      Fix (2): return/save the true final iterate (YZY) so the iterN == 1
               case produces the correct forward result and gradient.
    """
    @staticmethod
    def forward(ctx, input, iterN):
        # input: (batchSize, dim, dim) batch of matrices. The paper's
        # setting uses symmetric input and iterN > 1; the fixed code also
        # matches autograd outside that setting (see the issue text).
        x = input
        batchSize = x.data.shape[0]
        dim = x.data.shape[1]
        dtype = x.dtype
        # I3 = 3*I replicated per batch element; reused for the trace
        # computation below and inside the iteration formula 0.5*(3I - ZY).
        I3 = 3.0*torch.eye(dim,dim,device = x.device).view(1, dim, dim).repeat(batchSize,1,1).type(dtype)
        # normA = trace(x); dividing by it scales A into the Newton-Schulz
        # convergence region.
        normA = (1.0/3.0)*x.mul(I3).sum(dim=1).sum(dim=1)
        A = x.div(normA.view(batchSize,1,1).expand_as(x))
        # Per-step iterates are stored so backward can replay the chain:
        # Y[:,k] and Z[:,k] hold the k-th Newton-Schulz pair.
        Y = torch.zeros(batchSize, iterN, dim, dim, requires_grad = False, device = x.device).type(dtype)
        Z = torch.eye(dim,dim,device = x.device).view(1,dim,dim).repeat(batchSize,iterN,1,1).type(dtype)
        if iterN < 2:
            ZY = 0.5*(I3 - A)
            # Fix (2): keep the single-step result in YZY; the buggy
            # original stored it in Y and then returned ZY instead:
            # Y[:,0,:,:] = A.bmm(ZY)
            YZY = A.bmm(ZY)
        else:
            ZY = 0.5*(I3 - A)
            Y[:,0,:,:] = A.bmm(ZY)
            Z[:,0,:,:] = ZY
            for i in range(1, iterN-1):
                ZY = 0.5*(I3 - Z[:,i-1,:,:].bmm(Y[:,i-1,:,:]))
                Y[:,i,:,:] = Y[:,i-1,:,:].bmm(ZY)
                Z[:,i,:,:] = ZY.bmm(Z[:,i-1,:,:])
            # Fix (2): name the final half-step YZY rather than reusing ZY,
            # so the value saved for backward is unambiguous:
            # ZY = 0.5*Y[:,iterN-2,:,:].bmm(I3 - Z[:,iterN-2,:,:].bmm(Y[:,iterN-2,:,:]))
            YZY = 0.5*Y[:,iterN-2,:,:].bmm(I3 - Z[:,iterN-2,:,:].bmm(Y[:,iterN-2,:,:]))
        # Undo the trace normalization: sqrt(x) = sqrt(normA) * sqrt(A).
        # Fix (2): scale YZY, not ZY:
        # y = ZY*torch.sqrt(normA).view(batchSize, 1, 1).expand_as(x)
        y = YZY*torch.sqrt(normA).view(batchSize, 1, 1).expand_as(x)
        # Fix (2): save YZY (the actual returned iterate) for backward:
        # ctx.save_for_backward(input, A, ZY, normA, Y, Z)
        ctx.save_for_backward(input, A, YZY, normA, Y, Z)
        ctx.iterN = iterN
        return y
    @staticmethod
    def backward(ctx, grad_output):
        # NOTE: the third saved tensor is the final iterate YZY (see Fix (2)
        # in forward), although it is unpacked here under the name ``ZY``.
        input, A, ZY, normA, Y, Z = ctx.saved_tensors
        iterN = ctx.iterN
        x = input
        batchSize = x.data.shape[0]
        dim = x.data.shape[1]
        dtype = x.dtype
        # Gradient w.r.t. the pre-rescaling iterate (chain rule through the
        # final multiplication by sqrt(normA)).
        der_postCom = grad_output*torch.sqrt(normA).view(batchSize, 1, 1).expand_as(x)
        # Scalar per-sample gradient w.r.t. sqrt(normA) itself.
        der_postComAux = (grad_output*ZY).sum(dim=1).sum(dim=1).div(2*torch.sqrt(normA))
        I3 = 3.0*torch.eye(dim,dim,device = x.device).view(1, dim, dim).repeat(batchSize,1,1).type(dtype)
        if iterN < 2:
            # Single-step case: differentiate YZY = A * 0.5*(3I - A) directly.
            der_NSiter = 0.5*(der_postCom.bmm(I3 - A) - A.bmm(der_postCom))
        else:
            # Backpropagate through the final half-step, then unroll the
            # iteration in reverse, propagating dL/dY and dL/dZ together.
            dldY = 0.5*(der_postCom.bmm(I3 - Y[:,iterN-2,:,:].bmm(Z[:,iterN-2,:,:])) -
                         Z[:,iterN-2,:,:].bmm(Y[:,iterN-2,:,:]).bmm(der_postCom))
            dldZ = -0.5*Y[:,iterN-2,:,:].bmm(der_postCom).bmm(Y[:,iterN-2,:,:])
            for i in range(iterN-3, -1, -1):
                YZ = I3 - Y[:,i,:,:].bmm(Z[:,i,:,:])
                ZY = Z[:,i,:,:].bmm(Y[:,i,:,:])
                dldY_ = 0.5*(dldY.bmm(YZ) -
                          Z[:,i,:,:].bmm(dldZ).bmm(Z[:,i,:,:]) -
                              ZY.bmm(dldY))
                dldZ_ = 0.5*(YZ.bmm(dldZ) -
                          Y[:,i,:,:].bmm(dldY).bmm(Y[:,i,:,:]) -
                             dldZ.bmm(ZY))
                dldY = dldY_
                dldZ = dldZ_
            der_NSiter = 0.5*(dldY.bmm(I3 - A) - dldZ - A.bmm(dldY))
        # Fix (1): transpose the gradient w.r.t. A. Without this, the result
        # only agrees with autograd when the input matrix is symmetric.
        der_NSiter = der_NSiter.transpose(1, 2)
        grad_input = der_NSiter.div(normA.view(batchSize,1,1).expand_as(x))
        grad_aux = der_NSiter.mul(x).sum(dim=1).sum(dim=1)
        # Add the diagonal contribution from the dependence of the trace
        # normalization (normA) on x.
        for i in range(batchSize):
            grad_input[i,:,:] += (der_postComAux[i] \
                                  - grad_aux[i] / (normA[i] * normA[i])) \
                                  *torch.ones(dim,device = x.device).diag().type(dtype)
        # iterN is a non-tensor argument, so its gradient slot is None.
        return grad_input, None

def sqrtm1(x, iterN):
    """Batched matrix sqrt using the custom-autograd Sqrtm Function."""
    return Sqrtm.apply(x, iterN)

def sqrtm2(x, iterN):
    """Reference batched matrix square root built from differentiable ops.

    Same Newton-Schulz scheme as Sqrtm.forward, but written with plain
    torch operations so PyTorch autograd supplies the gradient — used as
    the ground truth to validate Sqrtm's hand-written backward.

    Fix vs. the original: the first half-step ``ZY = 0.5 * (I3 - A)`` was
    computed unconditionally and then recomputed identically inside the
    ``iterN < 2`` branch; the duplicate is removed (behavior unchanged).

    Args:
        x: (batchSize, dim, dim) batch of matrices.
        iterN: number of Newton-Schulz iterations (>= 1).

    Returns:
        (batchSize, dim, dim) tensor of approximate matrix square roots.
    """
    batchSize = x.shape[0]
    dim = x.shape[1]
    dtype = x.dtype
    # 3*I per batch element; used for the trace below and in the iteration.
    I3 = 3.0 * torch.eye(dim, dim, device=x.device).view(1, dim, dim).repeat(batchSize, 1, 1).type(dtype)
    # normA = trace(x); normalizing by it keeps the iteration convergent.
    normA = (1.0 / 3.0) * x.mul(I3).sum(dim=1).sum(dim=1)
    A = x.div(normA.view(batchSize, 1, 1).expand_as(x))
    ZY = 0.5 * (I3 - A)  # first half-step, shared by both branches
    if iterN < 2:
        YZY = A.bmm(ZY)
    else:
        Y = A.bmm(ZY)
        Z = ZY
        for _ in range(iterN - 2):
            ZY = 0.5 * (I3 - Z.bmm(Y))
            Y = Y.bmm(ZY)
            Z = ZY.bmm(Z)
        YZY = 0.5 * Y.bmm(I3 - Z.bmm(Y))
    # Undo the normalization: sqrt(x) = sqrt(normA) * sqrt(A).
    y = YZY * torch.sqrt(normA).view(batchSize, 1, 1).expand_as(x)
    return y

def test(input_array, iterN):
    """Compare Sqrtm's custom backward (sqrtm1) against autograd (sqrtm2).

    Runs both implementations on the same numpy input, backprops the sum
    of each output, and prints the gradients plus their differences.
    """
    def run(fn):
        # Fresh leaf tensor per run so .grad accumulation stays independent.
        leaf = torch.from_numpy(input_array).requires_grad_()
        out = fn(leaf, iterN)
        out.sum().backward()
        return out, leaf.grad.clone()

    out1, grad1 = run(sqrtm1)
    out2, grad2 = run(sqrtm2)

    print('gradient from sqrtm1:\n{}'.format(grad1))
    print('gradient from sqrtm2:\n{}'.format(grad2))
    print('diff between gradients:\n{}'.format(grad1 - grad2))
    print('max abs diff between outputs: {}'.format((out1 - out2).abs().max()))
    print('max abs diff between gradients: {}'.format((grad1 - grad2).abs().max()))

# Driver: compare gradients for one asymmetric sample and its symmetrization.
B, D = 2, 4
iterN = 2
asym = np.random.random([B, D, D])
sym = asym + asym.swapaxes(1, 2)  # A + A^T is symmetric
for label, arr in (('asymetric', asym), ('symetric', sym)):
    print('------------------ test with {} input:'.format(label))
    test(arr, iterN)

修复前的iterN = 1测试结果,对称和非对称都出错:

------------------ test with asymetric input:
gradient from sqrtm1:
tensor([[[0.6670, 0.4382, 0.4621, 0.5066],
         [0.5800, 0.5628, 0.4809, 0.5254],
         [0.5296, 0.4066, 0.5362, 0.4750],
         [0.4549, 0.3319, 0.3557, 0.5061]],

        [[0.5720, 0.5708, 0.5501, 0.5786],
         [0.4222, 0.6016, 0.4906, 0.5191],
         [0.4158, 0.5049, 0.5744, 0.5127],
         [0.4329, 0.5220, 0.5013, 0.6201]]], dtype=torch.float64)
gradient from sqrtm2:
tensor([[[0.1864, 0.5800, 0.5296, 0.4549],
         [0.4382, 0.0822, 0.4066, 0.3319],
         [0.4621, 0.4809, 0.0556, 0.3557],
         [0.5066, 0.5254, 0.4750, 0.0255]],

        [[0.0639, 0.4222, 0.4158, 0.4329],
         [0.5708, 0.0935, 0.5049, 0.5220],
         [0.5501, 0.4906, 0.0663, 0.5013],
         [0.5786, 0.5191, 0.5127, 0.1120]]], dtype=torch.float64)
diff between gradients:
tensor([[[ 0.4806, -0.1418, -0.0675,  0.0518],
         [ 0.1418,  0.4806,  0.0743,  0.1935],
         [ 0.0675, -0.0743,  0.4806,  0.1192],
         [-0.0518, -0.1935, -0.1192,  0.4806]],

        [[ 0.5081,  0.1486,  0.1343,  0.1457],
         [-0.1486,  0.5081, -0.0143, -0.0029],
         [-0.1343,  0.0143,  0.5081,  0.0114],
         [-0.1457,  0.0029, -0.0114,  0.5081]]], dtype=torch.float64)
max abs diff between outputs: 2.186709279694831
max abs diff between gradients: 0.5081247212245852
------------------ test with symetric input:
gradient from sqrtm1:
tensor([[[0.4454, 0.3600, 0.3506, 0.3399],
         [0.3600, 0.3717, 0.3137, 0.3031],
         [0.3506, 0.3137, 0.3529, 0.2937],
         [0.3399, 0.3031, 0.2937, 0.3316]],

        [[0.3865, 0.3511, 0.3415, 0.3576],
         [0.3511, 0.4074, 0.3520, 0.3681],
         [0.3415, 0.3520, 0.3882, 0.3585],
         [0.3576, 0.3681, 0.3585, 0.4205]]], dtype=torch.float64)
gradient from sqrtm2:
tensor([[[0.1515, 0.3600, 0.3506, 0.3399],
         [0.3600, 0.0778, 0.3137, 0.3031],
         [0.3506, 0.3137, 0.0590, 0.2937],
         [0.3399, 0.3031, 0.2937, 0.0377]],

        [[0.0587, 0.3511, 0.3415, 0.3576],
         [0.3511, 0.0796, 0.3520, 0.3681],
         [0.3415, 0.3520, 0.0604, 0.3585],
         [0.3576, 0.3681, 0.3585, 0.0927]]], dtype=torch.float64)
diff between gradients:
tensor([[[0.2940, 0.0000, 0.0000, 0.0000],
         [0.0000, 0.2940, 0.0000, 0.0000],
         [0.0000, 0.0000, 0.2940, 0.0000],
         [0.0000, 0.0000, 0.0000, 0.2940]],

        [[0.3278, 0.0000, 0.0000, 0.0000],
         [0.0000, 0.3278, 0.0000, 0.0000],
         [0.0000, 0.0000, 0.3278, 0.0000],
         [0.0000, 0.0000, 0.0000, 0.3278]]], dtype=torch.float64)
max abs diff between outputs: 3.1037870123546973
max abs diff between gradients: 0.3277939043459598

修复前的iterN = 5测试结果,对称没错,非对称出错:

------------------ test with asymetric input:
gradient from sqrtm1:
tensor([[[ 1.9384e+07, -8.5210e+06, -4.6972e+06, -1.2790e+06],
         [-1.5915e+07,  2.9645e+07,  8.0555e+06,  1.4257e+06],
         [-4.5993e+06,  4.9725e+06,  1.4856e+07,  4.2368e+05],
         [ 9.8476e+05, -1.9989e+06, -1.5218e+06,  1.1861e+07]],

        [[ 1.3860e+02,  1.1201e+02,  1.2900e+02, -2.8761e+02],
         [ 3.5804e+01,  1.0327e+02,  5.1839e+01, -1.0884e+02],
         [ 3.1528e+01,  4.2369e+01,  1.0474e+02, -1.0514e+02],
         [-9.0499e+00, -1.9190e+00,  7.6272e+00,  8.5914e+00]]],
       dtype=torch.float64)
gradient from sqrtm2:
tensor([[[ 1.7547e+07, -1.5915e+07, -4.5993e+06,  9.8476e+05],
         [-8.5210e+06,  2.7808e+07,  4.9725e+06, -1.9989e+06],
         [-4.6972e+06,  8.0555e+06,  1.3020e+07, -1.5218e+06],
         [-1.2790e+06,  1.4257e+06,  4.2368e+05,  1.0024e+07]],

        [[ 1.0755e+02,  3.5804e+01,  3.1528e+01, -9.0499e+00],
         [ 1.1201e+02,  7.2223e+01,  4.2369e+01, -1.9190e+00],
         [ 1.2900e+02,  5.1839e+01,  7.3690e+01,  7.6272e+00],
         [-2.8761e+02, -1.0884e+02, -1.0514e+02, -2.2455e+01]]],
       dtype=torch.float64)
diff between gradients:
tensor([[[ 1.8366e+06,  7.3936e+06, -9.7904e+04, -2.2638e+06],
         [-7.3936e+06,  1.8366e+06,  3.0830e+06,  3.4246e+06],
         [ 9.7904e+04, -3.0830e+06,  1.8366e+06,  1.9455e+06],
         [ 2.2638e+06, -3.4246e+06, -1.9455e+06,  1.8366e+06]],

        [[ 3.1047e+01,  7.6207e+01,  9.7471e+01, -2.7856e+02],
         [-7.6207e+01,  3.1047e+01,  9.4698e+00, -1.0692e+02],
         [-9.7471e+01, -9.4698e+00,  3.1047e+01, -1.1276e+02],
         [ 2.7856e+02,  1.0692e+02,  1.1276e+02,  3.1047e+01]]],
       dtype=torch.float64)
max abs diff between outputs: 0.0
max abs diff between gradients: 7393625.618993804
------------------ test with symetric input:
gradient from sqrtm1:
tensor([[[ 3.0105e+07, -1.9550e+07, -7.4120e+06, -1.1047e+06],
         [-1.9550e+07,  4.7142e+07,  1.0035e+07, -2.2143e+05],
         [-7.4120e+06,  1.0035e+07,  2.1771e+07, -5.0680e+05],
         [-1.1047e+06, -2.2143e+05, -5.0680e+05,  1.7518e+07]],

        [[ 2.2905e+02,  1.1386e+02,  1.0183e+02, -2.0465e+02],
         [ 1.1386e+02,  1.8236e+02,  8.0801e+01, -1.4623e+02],
         [ 1.0183e+02,  8.0801e+01,  1.6310e+02, -1.2689e+02],
         [-2.0465e+02, -1.4623e+02, -1.2689e+02,  2.3370e+02]]],
       dtype=torch.float64)
gradient from sqrtm2:
tensor([[[ 3.0105e+07, -1.9550e+07, -7.4120e+06, -1.1047e+06],
         [-1.9550e+07,  4.7142e+07,  1.0035e+07, -2.2143e+05],
         [-7.4120e+06,  1.0035e+07,  2.1771e+07, -5.0680e+05],
         [-1.1047e+06, -2.2143e+05, -5.0680e+05,  1.7518e+07]],

        [[ 2.2905e+02,  1.1386e+02,  1.0183e+02, -2.0465e+02],
         [ 1.1386e+02,  1.8236e+02,  8.0801e+01, -1.4623e+02],
         [ 1.0183e+02,  8.0801e+01,  1.6310e+02, -1.2689e+02],
         [-2.0465e+02, -1.4623e+02, -1.2689e+02,  2.3370e+02]]],
       dtype=torch.float64)
diff between gradients:
tensor([[[-7.4506e-09,  1.4901e-08,  6.5193e-09,  2.3283e-10],
         [-3.7253e-09, -7.4506e-09, -1.8626e-09, -3.3469e-09],
         [-4.6566e-09,  1.8626e-09, -3.7253e-09, -1.3388e-09],
         [-4.8894e-09,  6.6939e-09,  2.3865e-09, -3.7253e-09]],

        [[-2.8422e-14,  1.1369e-13,  8.5265e-14, -3.6948e-13],
         [-9.9476e-14, -2.8422e-14,  9.9476e-14, -1.7053e-13],
         [-7.1054e-14, -9.9476e-14,  0.0000e+00, -1.2790e-13],
         [ 3.4106e-13,  2.8422e-13,  4.2633e-14, -8.5265e-14]]],
       dtype=torch.float64)
max abs diff between outputs: 0.0
max abs diff between gradients: 1.4901161193847656e-08

修复后的iterN = 1测试结果,都没出错:

------------------ test with asymetric input:
gradient from sqrtm1:
tensor([[[ 3.9705, -0.3688, -0.1363,  0.3287],
         [-0.2902,  3.6506, -0.2570,  0.2081],
         [-0.5384, -0.7377,  3.6348, -0.0402],
         [-0.7213, -0.9205, -0.6881,  3.9170]],

        [[ 0.3193,  0.4983,  0.3268,  0.3002],
         [ 0.4432,  0.6263,  0.4528,  0.4262],
         [ 0.2070,  0.3881,  0.2186,  0.1900],
         [ 0.2688,  0.4498,  0.2783,  0.2538]]], dtype=torch.float64)
gradient from sqrtm2:
tensor([[[ 3.9705, -0.3688, -0.1363,  0.3287],
         [-0.2902,  3.6506, -0.2570,  0.2081],
         [-0.5384, -0.7377,  3.6348, -0.0402],
         [-0.7213, -0.9205, -0.6881,  3.9170]],

        [[ 0.3193,  0.4983,  0.3268,  0.3002],
         [ 0.4432,  0.6263,  0.4528,  0.4262],
         [ 0.2070,  0.3881,  0.2186,  0.1900],
         [ 0.2688,  0.4498,  0.2783,  0.2538]]], dtype=torch.float64)
diff between gradients:
tensor([[[ 0.0000e+00, -5.5511e-17, -8.3267e-17,  0.0000e+00],
         [ 1.1102e-16,  0.0000e+00,  0.0000e+00,  1.1102e-16],
         [ 2.2204e-16,  2.2204e-16,  0.0000e+00,  2.1511e-16],
         [ 2.2204e-16,  2.2204e-16,  1.1102e-16,  0.0000e+00]],

        [[ 0.0000e+00, -5.5511e-17,  5.5511e-17,  5.5511e-17],
         [ 0.0000e+00,  0.0000e+00,  1.1102e-16,  1.1102e-16],
         [-8.3267e-17, -2.2204e-16,  0.0000e+00,  0.0000e+00],
         [-5.5511e-17, -1.6653e-16,  0.0000e+00,  0.0000e+00]]],
       dtype=torch.float64)
max abs diff between outputs: 0.0
max abs diff between gradients: 2.220446049250313e-16
------------------ test with symetric input:
gradient from sqrtm1:
tensor([[[ 3.2350, -0.2330, -0.2386, -0.1388],
         [-0.2330,  3.0088, -0.3517, -0.2519],
         [-0.2386, -0.3517,  2.9976, -0.2575],
         [-0.1388, -0.2519, -0.2575,  3.1972]],

        [[ 0.2323,  0.3329,  0.1887,  0.2012],
         [ 0.3329,  0.4494,  0.2973,  0.3097],
         [ 0.1887,  0.2973,  0.1611,  0.1656],
         [ 0.2012,  0.3097,  0.1656,  0.1860]]], dtype=torch.float64)
gradient from sqrtm2:
tensor([[[ 3.2350, -0.2330, -0.2386, -0.1388],
         [-0.2330,  3.0088, -0.3517, -0.2519],
         [-0.2386, -0.3517,  2.9976, -0.2575],
         [-0.1388, -0.2519, -0.2575,  3.1972]],

        [[ 0.2323,  0.3329,  0.1887,  0.2012],
         [ 0.3329,  0.4494,  0.2973,  0.3097],
         [ 0.1887,  0.2973,  0.1611,  0.1656],
         [ 0.2012,  0.3097,  0.1656,  0.1860]]], dtype=torch.float64)
diff between gradients:
tensor([[[ 4.4409e-16,  0.0000e+00, -1.1102e-16, -1.1102e-16],
         [ 0.0000e+00,  4.4409e-16, -1.1102e-16, -5.5511e-17],
         [ 1.1102e-16,  1.1102e-16,  4.4409e-16,  0.0000e+00],
         [ 1.1102e-16,  5.5511e-17,  0.0000e+00,  4.4409e-16]],

        [[ 1.1102e-16,  0.0000e+00, -5.5511e-17,  2.7756e-17],
         [ 0.0000e+00,  1.1102e-16,  0.0000e+00,  5.5511e-17],
         [ 5.5511e-17,  0.0000e+00,  1.1102e-16,  8.3267e-17],
         [-2.7756e-17, -5.5511e-17, -8.3267e-17,  1.1102e-16]]],
       dtype=torch.float64)
max abs diff between outputs: 0.0
max abs diff between gradients: 4.440892098500626e-16

修复后的iterN = 5测试结果,都没出错:

------------------ test with asymetric input:
gradient from sqrtm1:
tensor([[[ 2.6101e+02,  1.6915e+02,  3.6617e+01, -1.1991e+02],
         [-2.6940e+02,  4.9514e+02,  3.0622e+02, -1.0844e+02],
         [-6.7110e+01, -1.5917e+02,  2.2811e+02,  1.1199e+02],
         [ 1.8232e+02, -2.0789e+02, -2.2965e+02,  3.4158e+02]],

        [[-8.9861e+05, -1.3094e+07,  4.7909e+06,  1.3310e+07],
         [ 4.3960e+06,  3.6968e+06,  1.4234e+06, -2.2849e+07],
         [ 1.0032e+07, -1.0482e+07,  5.3876e+06, -2.0861e+06],
         [-6.4518e+06, -4.6163e+05, -5.5865e+06,  3.4330e+06]]],
       dtype=torch.float64)
gradient from sqrtm2:
tensor([[[ 2.6101e+02,  1.6915e+02,  3.6617e+01, -1.1991e+02],
         [-2.6940e+02,  4.9514e+02,  3.0622e+02, -1.0844e+02],
         [-6.7110e+01, -1.5917e+02,  2.2811e+02,  1.1199e+02],
         [ 1.8232e+02, -2.0789e+02, -2.2965e+02,  3.4158e+02]],

        [[-8.9861e+05, -1.3094e+07,  4.7909e+06,  1.3310e+07],
         [ 4.3960e+06,  3.6968e+06,  1.4234e+06, -2.2849e+07],
         [ 1.0032e+07, -1.0482e+07,  5.3876e+06, -2.0861e+06],
         [-6.4518e+06, -4.6163e+05, -5.5865e+06,  3.4330e+06]]],
       dtype=torch.float64)
diff between gradients:
tensor([[[ 1.1369e-13, -1.4211e-13, -7.1054e-14,  2.8422e-14],
         [ 1.7053e-13, -1.1369e-13, -2.2737e-13, -4.2633e-14],
         [-4.2633e-14,  1.1369e-13,  0.0000e+00, -4.2633e-14],
         [-1.9895e-13,  5.6843e-14,  1.4211e-13,  0.0000e+00]],

        [[-2.0955e-08, -3.7253e-09,  6.5193e-09,  1.6764e-08],
         [ 1.2107e-08, -1.6764e-08, -1.3970e-09, -2.9802e-08],
         [ 3.7253e-09, -7.4506e-09, -1.3039e-08,  9.0804e-09],
         [-1.6764e-08,  9.0804e-09, -1.5832e-08, -1.8626e-08]]],
       dtype=torch.float64)
max abs diff between outputs: 0.0
max abs diff between gradients: 2.9802322387695312e-08
------------------ test with symetric input:
gradient from sqrtm1:
tensor([[[ 1.3715e+05,  5.3459e+04, -2.6848e+04, -3.0884e+04],
         [ 5.3459e+04,  5.7213e+04, -2.9028e+04, -1.7253e+04],
         [-2.6848e+04, -2.9028e+04,  4.2635e+03,  5.4074e+03],
         [-3.0884e+04, -1.7253e+04,  5.4074e+03,  4.4265e+04]],

        [[ 1.3055e+10, -4.3848e+09,  1.5062e+09,  4.1697e+09],
         [-4.3848e+09,  1.9282e+10, -4.4168e+09, -9.0337e+09],
         [ 1.5062e+09, -4.4168e+09,  1.1490e+10,  3.3433e+09],
         [ 4.1697e+09, -9.0337e+09,  3.3433e+09,  1.9858e+10]]],
       dtype=torch.float64)
gradient from sqrtm2:
tensor([[[ 1.3715e+05,  5.3459e+04, -2.6848e+04, -3.0884e+04],
         [ 5.3459e+04,  5.7213e+04, -2.9028e+04, -1.7253e+04],
         [-2.6848e+04, -2.9028e+04,  4.2635e+03,  5.4074e+03],
         [-3.0884e+04, -1.7253e+04,  5.4074e+03,  4.4265e+04]],

        [[ 1.3055e+10, -4.3848e+09,  1.5062e+09,  4.1697e+09],
         [-4.3848e+09,  1.9282e+10, -4.4168e+09, -9.0337e+09],
         [ 1.5062e+09, -4.4168e+09,  1.1490e+10,  3.3433e+09],
         [ 4.1697e+09, -9.0337e+09,  3.3433e+09,  1.9858e+10]]],
       dtype=torch.float64)
diff between gradients:
tensor([[[-8.7311e-11, -4.3656e-11,  4.0018e-11,  1.8190e-11],
         [-2.9104e-11, -5.8208e-11,  2.5466e-11,  1.0914e-11],
         [ 4.3656e-11,  1.8190e-11, -9.4587e-11, -1.6371e-11],
         [ 2.1828e-11,  3.6380e-12, -2.4556e-11, -4.3656e-11]],

        [[-5.7220e-06,  9.5367e-07, -2.6226e-06, -2.3842e-06],
         [ 2.8610e-06, -1.1444e-05,  5.7220e-06,  7.6294e-06],
         [ 2.3842e-07,  0.0000e+00, -5.7220e-06,  1.4305e-06],
         [-4.2915e-06,  3.8147e-06, -3.8147e-06, -1.1444e-05]]],
       dtype=torch.float64)
max abs diff between outputs: 0.0
max abs diff between gradients: 1.1444091796875e-05
jiangtaoxie commented 5 years ago

@WarBean Great! Thanks for your amazing work. These are problems I had never found before, because the commonly used settings are iterN=5 with symmetric input. Thank you again!