hahnyuan / nn_tools

Neural Network Tools: Converter and Analyzer for Caffe, PyTorch, Darknet, and so on.
MIT License
355 stars 63 forks source link

nn.BatchNorm2d accuracy #14

Closed: adeagle closed this issue 5 years ago

adeagle commented 5 years ago
from models import *
from utils.datasets import *
from utils.utils import *
import torch
from torch.autograd import Variable
from torchvision.models.densenet import *
import sys
from collections import OrderedDict

def generate_random(shape,gpu=False):
    """Create one random input in both numpy and torch form.

    Values are drawn uniformly from [0, 1) with ``np.random.rand`` and
    reshaped to ``shape``. The torch copy is built with ``torch.Tensor``
    (i.e. the default tensor type) and wrapped in ``Variable``; it is moved
    to the GPU when ``gpu`` is true.

    Returns:
        ``([ndarray], [tensor])`` -- two single-element lists, the form the
        caffe/torch forward helpers expect as their ``data`` argument.
    """
    flat = np.random.rand(np.prod(shape))
    array = flat.reshape(shape)
    tensor = Variable(torch.Tensor(array))
    tensor = tensor.cuda() if gpu else tensor
    return [array], [tensor]

def get_input_size(caffe_net):
    """Return the shape of the data array of ``caffe_net``'s first input blob."""
    first_input = caffe_net.inputs[0]
    blob = caffe_net.blobs[first_input]
    return blob.data.shape

def forward_torch(net,data):
    """Run ``net`` on ``data`` and record the output of every submodule.

    Every module that has an ``inplace`` attribute is switched to
    ``inplace=False`` (this change persists on ``net``). A forward hook is
    temporarily attached to every module to capture its output as a numpy
    array; the hooks are removed again before returning, even if the
    forward pass raises.

    Args:
        net: the ``nn.Module`` to run; ``net.forward(*data)`` is called.
        data: sequence of positional inputs for ``net.forward``.

    Returns:
        ``(blobs, outputs)``: ``blobs`` is an OrderedDict mapping module names
        (with '.' replaced by '_', the form test() looks up against Caffe
        layer names) to that module's last output, in hook-firing order.
        Modules whose output has no ``.data`` (e.g. tuples) are not recorded.
        ``outputs`` is a list of numpy arrays for the network output (one
        per element if the output is a list or tuple).
    """
    blobs = OrderedDict()
    module2name = {}
    for layer_name, m in net.named_modules():
        module2name[m] = layer_name.replace('.', '_')
        # Turn off in-place ops: on CPU, .numpy() shares memory with the
        # tensor, so a later in-place op would overwrite a recorded blob.
        if hasattr(m, 'inplace'):
            m.inplace = False

    def forward_hook(m, i, o):
        try:
            o_np = o.data.cpu().numpy()
        except AttributeError:
            # Non-tensor outputs (tuples, None, ...) have no single array to
            # record; skip them instead of aborting the whole forward pass.
            return
        blobs[module2name[m]] = o_np

    handles = [m.register_forward_hook(forward_hook) for m in net.modules()]
    try:
        output = net.forward(*data)
    finally:
        # Remove the hooks so repeated calls do not stack them on the model.
        for handle in handles:
            handle.remove()

    if isinstance(output, (list, tuple)):
        outputs = [o.data.cpu().numpy() for o in output]
    else:
        outputs = [output.data.cpu().numpy()]
    return blobs, outputs

def forward_caffe(net,data):
    """Run a Caffe net on ``data`` and collect every blob, grouped by layer.

    Args:
        net: a loaded ``caffe.Net``.
        data: sequence of arrays, copied in order into the blobs named by
            ``net.inputs``.

    Returns:
        ``(blobs, outputs)``: ``blobs`` is an OrderedDict mapping the name of
        the layer that lists each blob as a top to a list of copies of those
        blob arrays, in ``net.blobs`` order. A blob that no layer lists as a
        top (e.g. an input declared with ``input:``) is keyed by its own blob
        name. ``outputs`` holds copies of the arrays returned by
        ``net.forward()``, in ``net.outputs`` order.
    """
    for input_name, d in zip(net.inputs, data):
        net.blobs[input_name].data[...] = d
    rst = net.forward()

    blob2layer = {}
    for layer_name, tops in net.top_names.items():
        for top in tops:
            blob2layer[top] = layer_name

    blobs = OrderedDict()
    for name, blob in net.blobs.items():
        # Fall back to the blob name instead of raising KeyError for blobs
        # that no layer produces.
        layer_name = blob2layer.get(name, name)
        # blob.data is written in place (as the input assignment above
        # relies on), so copy it to keep results stable across later passes.
        blobs.setdefault(layer_name, []).append(blob.data.copy())

    outputs = [rst[output_name].copy() for output_name in net.outputs]
    return blobs, outputs

def test(net_caffe,net_torch,data_np,data_torch,decimal=2):
    """Run both networks on the same input and compare them layer by layer.

    For every Caffe layer whose name also appears among the PyTorch blobs,
    the layer's first Caffe blob is compared with the PyTorch module output
    to ``decimal`` places, and "TEST layer <name>: PASS/FAIL" is printed.
    Layers whose element counts differ, or whose names contain 'relu', are
    skipped.

    On a mismatch the two arrays are compared again with negative values
    clipped to zero. If they still differ, the AssertionError propagates and
    the remaining layers are not checked.

    Args:
        net_caffe: a loaded caffe.Net.
        net_torch: the PyTorch module under test.
        data_np: list of numpy inputs for the Caffe net.
        data_torch: list of tensor inputs for the PyTorch module.
        decimal: decimal places for both comparisons (default 2, the value
            previously hard-coded for the first comparison).
    """
    # The final network outputs are returned by the helpers but not compared.
    blobs_caffe, _ = forward_caffe(net_caffe, data_np)
    blobs_torch, _ = forward_torch(net_torch, data_torch)
    for layer, values in blobs_caffe.items():
        if layer not in blobs_torch:
            continue
        value = values[0]
        value_torch = blobs_torch[layer]
        if value.size != value_torch.size:
            continue
        if 'relu' in layer:
            continue
        # Same element count may still mean a different shape (e.g. after a
        # flatten); align shapes so the comparison checks values, not shape.
        value_torch = value_torch.reshape(value.shape)
        try:
            np.testing.assert_almost_equal(value, value_torch, decimal=decimal)
        except AssertionError:
            print("TEST layer {}: FAIL".format(layer))
            # Retry ignoring negative values, using the same tolerance. The
            # previous np.clip(x, min=0) raised TypeError on NumPy < 2.1, and
            # its default decimal=7 was stricter than the first check.
            np.testing.assert_almost_equal(np.clip(value, 0, None),
                                           np.clip(value_torch, 0, None),
                                           decimal=decimal)
        else:
            print("TEST layer {}: PASS".format(layer))

class CNN(nn.Module):
    """Minimal test network: one Conv2d -> BatchNorm2d -> ReLU stage.

    The stage lives in an ``nn.Sequential`` attribute named ``conv1``, so
    the parameters are named ``conv1.0.*`` (convolution) and ``conv1.1.*``
    (batch norm). These names are the keys of the state dict that trans()
    saves and verfiy() loads. The forward output is flattened to
    ``(batch, -1)``.
    """

    def __init__(self):
        super(CNN, self).__init__()
        conv = nn.Conv2d(3, 16, kernel_size=3, stride=1, padding=1)
        norm = nn.BatchNorm2d(num_features=16)
        act = nn.ReLU()
        self.conv1 = nn.Sequential(conv, norm, act)

    def forward(self, x):
        features = self.conv1(x)
        batch = features.size(0)
        return features.view(batch, -1)

def trans():
    """Build a fresh CNN, save its weights, and convert it to Caffe.

    Writes ``model.pt`` (a dict ``{'model': state_dict}``) and, via
    ``pytorch_to_caffe``, ``demo.prototxt`` and ``demo.caffemodel`` in the
    current directory. A 1x3x608x608 all-ones tensor is used as the tracing
    input. The model is switched to eval mode first, so the traced forward
    uses BatchNorm's running statistics rather than batch statistics.
    """
    def trans_imp(model):
        import pytorch_to_caffe
        name = "demo"
        dummy_input = Variable(torch.ones([1, 3, 608, 608]))
        pytorch_to_caffe.trans_net(model, dummy_input, name)
        pytorch_to_caffe.save_prototxt('{}.prototxt'.format(name))
        pytorch_to_caffe.save_caffemodel('{}.caffemodel'.format(name))

    model = CNN()
    model.eval()

    checkpoint = {'model': model.state_dict()}
    torch.save(checkpoint, "model.pt")
    trans_imp(model)
    print("trans success")


def verfiy():
    """Load the saved CNN and compare it layer by layer with the Caffe model.

    Loads ``model.pt`` into a new CNN and ``demo.prototxt`` /
    ``demo.caffemodel`` into a Caffe net in TEST phase. It prints the Caffe
    input shape, feeds both networks the same random input of that shape,
    and runs test() to print a PASS/FAIL line for each matching layer.
    """
    model = CNN()
    # Must be eval(), not train(). In train mode BatchNorm normalizes with
    # the current batch's statistics and updates its running stats. A Caffe
    # BatchNorm layer in TEST phase uses the stored (global) statistics by
    # default. Calling train() here caused the BatchNorm mismatch.
    model.eval()
    model.load_state_dict(torch.load("model.pt", map_location='cpu')['model'])

    import caffe

    net_caffe = caffe.Net('demo.prototxt', 'demo.caffemodel', caffe.TEST)
    shape = get_input_size(net_caffe)
    print(shape)
    data_np, data_torch = generate_random(shape)
    test(net_caffe, model, data_np, data_torch)

if __name__ == '__main__':
    # Exactly one command-line argument selects conversion (trans); any
    # other argument count runs the PyTorch-vs-Caffe comparison (verfiy).
    action = trans if len(sys.argv) == 2 else verfiy
    action()