TrojanXu / onnxparser-trt-plugin-sample

A sample for the ONNX parser working with user-defined TRT plugins for TRT 7.0
Apache License 2.0

Inconsistent results with torch.nn.functional.grid_sample #16

Closed. lhao0301 closed this issue 2 years ago

lhao0301 commented 2 years ago

The grid_sample plugin produces results inconsistent with torch.nn.functional.grid_sample.

npy files:

- g_feat.npy: https://drive.google.com/file/d/18YPjH4R8GBew_LZdgziHwze2fMkaYvoP/view?usp=sharing
- g_grid.npy: https://drive.google.com/file/d/18tOWjSN-rPUA6odTBBloS7_ERDQ-_a2N/view?usp=sharing

The code to reproduce the issue is as follows:

def torch2onnx():
    import torch
    import numpy as np

    # Minimal module that wraps grid_sample so it can be exported to ONNX.
    class Debug(torch.nn.Module):
        def __init__(self):
            super(Debug, self).__init__()

        def forward(self, feat, grid):
            # mode='bilinear', padding_mode='zeros', align_corners=True
            return torch.nn.functional.grid_sample(feat, grid, 'bilinear', 'zeros', True)

    net = Debug().cuda()
    net.eval()
    feat = torch.from_numpy(np.load('g_feat.npy')).cuda()
    grid = torch.from_numpy(np.load('g_grid.npy')).cuda()
    with torch.no_grad():
        # ONNX opset 16 is the first opset that defines a native GridSample op.
        torch.onnx.export(
            net, (feat, grid), 'gs.onnx', verbose=True,
            input_names=['feat', 'grid'], output_names=['corr'],
            opset_version=16, do_constant_folding=True)
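
The step that turns gs.onnx into the gs.trt engine used below is not shown in the report. As a reference, here is a minimal sketch of how such an engine could be built with the TensorRT 7 Python API; the plugin library path (libgridsampler.so), workspace size, and function name are placeholders of mine, not part of the original report.

import ctypes
import tensorrt as trt

TRT_LOGGER = trt.Logger(trt.Logger.INFO)

def build_trt_engine(onnx_path='gs.onnx', engine_path='gs.trt',
                     plugin_lib='./libgridsampler.so'):  # plugin_lib path is hypothetical
    # Load the custom plugin library so the ONNX parser can resolve the grid_sample node.
    ctypes.CDLL(plugin_lib)
    trt.init_libnvinfer_plugins(TRT_LOGGER, '')

    explicit_batch = 1 << int(trt.NetworkDefinitionCreationFlag.EXPLICIT_BATCH)
    with trt.Builder(TRT_LOGGER) as builder, \
            builder.create_network(explicit_batch) as network, \
            trt.OnnxParser(network, TRT_LOGGER) as parser:
        with open(onnx_path, 'rb') as f:
            if not parser.parse(f.read()):
                for i in range(parser.num_errors):
                    print(parser.get_error(i))
                raise RuntimeError('failed to parse ' + onnx_path)
        config = builder.create_builder_config()
        config.max_workspace_size = 1 << 30  # 1 GiB, arbitrary choice
        engine = builder.build_engine(network, config)
        with open(engine_path, 'wb') as f:
            f.write(engine.serialize())

With the engine serialized to gs.trt, the check below deserializes it and runs it against the same inputs.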

def check_torch_and_trt():
    import numpy as np
    import torch
    import pycuda.autoinit  # creates a CUDA context for pycuda
    import pycuda.driver as cuda

    trt_file = 'gs.trt'
    engine = load_engine(trt_file)  # deserializes the engine; see the sketch below

    context = engine.create_execution_context()

    feat = torch.from_numpy(np.load('g_feat.npy')).cuda()
    grid = torch.from_numpy(np.load('g_grid.npy')).cuda()

    # Reference output previously saved from PyTorch.
    out_list = [torch.from_numpy(np.load(key + '.npy')).cuda() for key in ['g_out']]

    out_pred = [torch.empty_like(item) for item in out_list]

    stream = cuda.Stream()
    # Binding order must match the engine's binding order: inputs first, then outputs.
    bindings = [feat.data_ptr(), grid.data_ptr()] + [item.contiguous().data_ptr() for item in out_pred]

    context.execute_async_v2(bindings=bindings, stream_handle=stream.handle)
    # Synchronize before reading the output; without it the comparison may see incomplete results.
    stream.synchronize()

    torch_out = torch.nn.functional.grid_sample(feat, grid, 'bilinear', 'zeros', True)
    print(torch.max(torch.abs(torch_out - out_pred[0])))
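
load_engine is referenced above but not defined in the snippet. A minimal sketch of what it is assumed to do, using the TensorRT Python runtime (the plugin library must already be loaded in the process before deserialization):

import tensorrt as trt

TRT_LOGGER = trt.Logger(trt.Logger.INFO)

def load_engine(trt_file):
    # Register plugins so the runtime can find the grid_sample plugin creator.
    trt.init_libnvinfer_plugins(TRT_LOGGER, '')
    with open(trt_file, 'rb') as f, trt.Runtime(TRT_LOGGER) as runtime:
        return runtime.deserialize_cuda_engine(f.read())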
lhao0301 commented 2 years ago

My mistake!