pytorch-labs / segment-anything-fast

A batched, offline-inference-oriented version of segment-anything
Apache License 2.0
1.19k stars 70 forks source link

the results on cpu and gpu are very different #121

Open zhangvia opened 4 months ago

zhangvia commented 4 months ago

Here is a short test to reproduce the problem. As you can see, the ImageEncoderViT produces noticeably different results on CPU and GPU. This level of numerical divergence is not acceptable.

Weirdly, when I print x just before passing it to the self.neck layer, the x values on CPU and GPU are identical. self.neck consists only of Conv2d and LayerNorm layers, so I have no idea why self.neck would produce different results on CPU versus GPU. Can anyone help?

>>> from segment_anything_fast.modeling.image_encoder import ImageEncoderViT
>>> import torch
>>> from functools import partial
>>> model = ImageEncoderViT(depth=32,embed_dim=1280,img_size=1024,mlp_ratio=4,norm_layer=partial(torch.nn.LayerNorm, eps=1e-6),num_heads=16,patch_size=16,qkv_bias=True,use_rel_pos=True,global_attn_indexes=[7, 15, 23, 31],window_size=14,out_chans=256,)
>>> input = torch.randn((1,3,1024,1024))
>>> a = model(input)
>>> input = input.to('cuda')
>>> model = model.to('cuda')
>>> b = model(input)
>>> a
tensor([[[[ 1.1279e+00, -3.9965e-01,  1.3351e+00,  ...,  6.2123e-01,
           -2.6559e-01,  8.9881e-01],
          [ 7.6766e-01,  1.1112e+00,  9.9850e-01,  ...,  8.1880e-01,
            4.4002e-01, -1.1236e+00],
          [ 4.8820e-01,  1.3826e+00,  1.6246e+00,  ...,  5.2780e-01,
           -1.2887e+00, -1.8294e-01],
          ...,
          [ 5.3505e-01,  6.5830e-01,  8.6872e-01,  ..., -5.3345e-02,
           -2.3688e-01, -5.8527e-01],
          [ 8.2007e-01,  1.1339e+00,  6.3609e-01,  ...,  1.2077e+00,
           -1.3440e+00,  1.3187e-02],
          [-1.8615e-01,  9.4821e-01,  5.8698e-01,  ..., -2.2219e-02,
            2.2753e-01, -1.0425e+00]],

         [[-5.4213e-01,  6.6123e-02,  7.4238e-01,  ..., -8.3169e-01,
           -5.6192e-01, -1.2276e+00],
          [ 3.6463e-01,  3.3585e-01,  1.0749e+00,  ...,  1.3571e+00,
           -3.0437e-01,  6.2709e-01],
          [-3.5779e-03, -3.8753e-03,  1.0144e+00,  ...,  1.5758e-01,
           -3.4319e-01,  1.0236e+00],
          ...,
          [ 5.7304e-01,  1.6791e+00,  7.4203e-01,  ...,  2.3011e+00,
            1.2866e+00,  1.0487e+00],
          [ 1.0945e+00,  1.3373e+00,  6.1515e-01,  ...,  1.9827e+00,
            1.4084e+00,  1.2886e+00],
          [ 2.2072e-01,  9.9274e-01, -4.1084e-02,  ...,  4.0965e-01,
            1.0271e+00,  7.0089e-01]],

         [[ 9.2292e-01,  1.2548e+00,  1.4542e+00,  ...,  1.3106e+00,
           -4.0816e-01,  1.3618e+00],
          [ 1.5234e+00,  5.9021e-01,  1.1496e+00,  ...,  1.1651e+00,
            1.6763e+00,  4.5780e-01],
          [ 9.5586e-01,  1.3678e+00,  1.6812e+00,  ...,  1.8888e-01,
            5.9193e-01,  2.3960e+00],
          ...,
          [ 1.3847e+00,  1.3367e+00, -2.9702e-02,  ...,  7.4956e-01,
            9.3687e-01,  1.1981e+00],
          [ 1.0065e+00,  9.2275e-01,  1.3572e+00,  ...,  2.0780e+00,
            8.4323e-01,  3.0510e-01],
          [ 1.4033e+00,  1.1534e+00, -5.7349e-01,  ..., -1.2417e-03,
            1.1398e+00,  1.4818e+00]],

         ...,

         [[-1.0704e+00, -1.4728e-02, -2.9317e-01,  ...,  1.2236e+00,
           -1.2223e-01, -1.0641e+00],
          [-2.4054e-01, -1.3002e-01, -3.1265e-02,  ...,  7.6334e-01,
           -1.5491e-01, -6.8875e-01],
          [ 4.5375e-01,  5.3694e-01, -7.1474e-01,  ...,  9.6021e-01,
            8.3933e-03,  7.1659e-01],
          ...,
          [-4.5589e-01,  1.0988e+00, -7.2541e-02,  ...,  6.7271e-01,
            1.7812e+00, -1.5954e+00],
          [ 4.6227e-01, -4.6240e-01, -1.4692e-02,  ...,  2.4024e+00,
           -1.4458e+00, -8.7657e-01],
          [-9.7614e-01, -6.4876e-01,  6.0639e-01,  ..., -1.2481e+00,
           -3.0824e-01,  1.1255e+00]],

         [[ 9.7512e-01,  6.5605e-01,  9.7539e-01,  ...,  1.1063e+00,
           -5.2691e-03, -1.5871e-01],
          [ 1.7015e-01,  4.2257e-01,  9.6964e-01,  ..., -2.5970e-01,
            3.9704e-01, -5.1129e-01],
          [ 7.7848e-01,  4.4700e-01,  4.3270e-01,  ...,  3.3544e-01,
            1.4105e+00,  6.6389e-01],
          ...,
          [-2.5888e-01,  1.2881e-01,  2.1639e-01,  ..., -1.2198e+00,
            1.0314e+00, -7.4611e-02],
          [ 7.9196e-01,  4.3873e-01,  8.0761e-01,  ...,  9.0232e-01,
            1.5996e+00,  8.0240e-01],
          [ 3.8602e-01,  5.4603e-01,  3.3816e-01,  ..., -6.9882e-01,
           -1.4301e+00, -4.7366e-01]],

         [[-2.9780e-01, -1.3218e+00, -2.6823e-01,  ..., -1.1618e+00,
           -2.9855e-01,  3.9646e-01],
          [ 3.1010e-01,  4.9516e-01,  1.7241e-01,  ..., -5.7428e-01,
           -1.1108e+00, -1.3312e+00],
          [-3.1486e-01,  7.0213e-01,  8.5001e-01,  ..., -1.0412e+00,
           -1.0719e+00, -5.6568e-01],
          ...,
          [-1.6898e-01, -7.2126e-01,  4.5796e-02,  ...,  7.2761e-01,
           -2.3792e-01, -7.9595e-01],
          [-1.7269e+00, -1.7630e+00, -9.0964e-01,  ..., -9.1442e-01,
            4.1963e-01,  1.3568e+00],
          [-4.8538e-01,  8.0793e-02,  3.3132e-01,  ..., -1.9279e-01,
           -4.4357e-01,  7.3314e-01]]]], grad_fn=<AddBackward0>)
>>> b
tensor([[[[ 1.1276e+00, -3.9896e-01,  1.3352e+00,  ...,  6.2086e-01,
           -2.6533e-01,  8.9948e-01],
          [ 7.6792e-01,  1.1115e+00,  9.9893e-01,  ...,  8.1904e-01,
            4.4017e-01, -1.1244e+00],
          [ 4.8875e-01,  1.3828e+00,  1.6243e+00,  ...,  5.2787e-01,
           -1.2887e+00, -1.8287e-01],
          ...,
          [ 5.3505e-01,  6.5837e-01,  8.6899e-01,  ..., -5.2909e-02,
           -2.3689e-01, -5.8535e-01],
          [ 8.2054e-01,  1.1339e+00,  6.3626e-01,  ...,  1.2077e+00,
           -1.3434e+00,  1.4043e-02],
          [-1.8604e-01,  9.4874e-01,  5.8726e-01,  ..., -2.1825e-02,
            2.2782e-01, -1.0425e+00]],

         [[-5.4216e-01,  6.5893e-02,  7.4183e-01,  ..., -8.3199e-01,
           -5.6210e-01, -1.2276e+00],
          [ 3.6446e-01,  3.3610e-01,  1.0743e+00,  ...,  1.3566e+00,
           -3.0510e-01,  6.2653e-01],
          [-3.5442e-03, -4.2382e-03,  1.0142e+00,  ...,  1.5751e-01,
           -3.4393e-01,  1.0233e+00],
          ...,
          [ 5.7268e-01,  1.6793e+00,  7.4196e-01,  ...,  2.3008e+00,
            1.2864e+00,  1.0479e+00],
          [ 1.0941e+00,  1.3378e+00,  6.1509e-01,  ...,  1.9829e+00,
            1.4082e+00,  1.2889e+00],
          [ 2.2067e-01,  9.9273e-01, -4.0986e-02,  ...,  4.1015e-01,
            1.0265e+00,  7.0073e-01]],

         [[ 9.2357e-01,  1.2548e+00,  1.4542e+00,  ...,  1.3107e+00,
           -4.0825e-01,  1.3620e+00],
          [ 1.5236e+00,  5.9024e-01,  1.1501e+00,  ...,  1.1650e+00,
            1.6763e+00,  4.5799e-01],
          [ 9.5630e-01,  1.3678e+00,  1.6815e+00,  ...,  1.8899e-01,
            5.9160e-01,  2.3954e+00],
          ...,
          [ 1.3852e+00,  1.3368e+00, -2.9295e-02,  ...,  7.4981e-01,
            9.3762e-01,  1.1983e+00],
          [ 1.0066e+00,  9.2232e-01,  1.3569e+00,  ...,  2.0778e+00,
            8.4387e-01,  3.0434e-01],
          [ 1.4036e+00,  1.1533e+00, -5.7359e-01,  ..., -1.0455e-03,
            1.1398e+00,  1.4820e+00]],

         ...,

         [[-1.0698e+00, -1.4772e-02, -2.9359e-01,  ...,  1.2239e+00,
           -1.2180e-01, -1.0637e+00],
          [-2.4025e-01, -1.2997e-01, -3.1515e-02,  ...,  7.6373e-01,
           -1.5430e-01, -6.8904e-01],
          [ 4.5413e-01,  5.3748e-01, -7.1436e-01,  ...,  9.6014e-01,
            9.0057e-03,  7.1679e-01],
          ...,
          [-4.5574e-01,  1.0993e+00, -7.2000e-02,  ...,  6.7207e-01,
            1.7814e+00, -1.5952e+00],
          [ 4.6145e-01, -4.6165e-01, -1.4222e-02,  ...,  2.4023e+00,
           -1.4463e+00, -8.7620e-01],
          [-9.7515e-01, -6.4810e-01,  6.0649e-01,  ..., -1.2477e+00,
           -3.0792e-01,  1.1256e+00]],

         [[ 9.7454e-01,  6.5678e-01,  9.7528e-01,  ...,  1.1065e+00,
           -5.0216e-03, -1.5847e-01],
          [ 1.7006e-01,  4.2306e-01,  9.6959e-01,  ..., -2.5992e-01,
            3.9714e-01, -5.1084e-01],
          [ 7.7854e-01,  4.4686e-01,  4.3235e-01,  ...,  3.3460e-01,
            1.4102e+00,  6.6323e-01],
          ...,
          [-2.5823e-01,  1.2859e-01,  2.1663e-01,  ..., -1.2196e+00,
            1.0320e+00, -7.4387e-02],
          [ 7.9208e-01,  4.3808e-01,  8.0768e-01,  ...,  9.0169e-01,
            1.5992e+00,  8.0307e-01],
          [ 3.8559e-01,  5.4567e-01,  3.3847e-01,  ..., -6.9920e-01,
           -1.4303e+00, -4.7335e-01]],

         [[-2.9799e-01, -1.3227e+00, -2.6864e-01,  ..., -1.1625e+00,
           -2.9797e-01,  3.9606e-01],
          [ 3.0952e-01,  4.9551e-01,  1.7227e-01,  ..., -5.7379e-01,
           -1.1105e+00, -1.3312e+00],
          [-3.1462e-01,  7.0210e-01,  8.5023e-01,  ..., -1.0416e+00,
           -1.0716e+00, -5.6509e-01],
          ...,
          [-1.6905e-01, -7.2131e-01,  4.6613e-02,  ...,  7.2798e-01,
           -2.3790e-01, -7.9638e-01],
          [-1.7270e+00, -1.7633e+00, -9.0924e-01,  ..., -9.1347e-01,
            4.2046e-01,  1.3563e+00],
          [-4.8477e-01,  8.0591e-02,  3.3117e-01,  ..., -1.9234e-01,
           -4.4338e-01,  7.3338e-01]]]], device='cuda:0',
       grad_fn=<AddBackward0>)
>>>