from models import *
from utils.datasets import *
from utils.utils import *
import torch
from torch.autograd import Variable
from torchvision.models.densenet import *
import sys
from collections import OrderedDict
def generate_random(shape, gpu=False):
    """Create one random input in both NumPy and Torch form.

    Samples uniformly in [0, 1) with ``np.random.rand`` and reshapes the
    samples to ``shape``. The Torch copy is built with ``torch.Tensor`` (the
    default tensor type, so float64 samples are converted) and is moved to
    CUDA when ``gpu`` is true.

    Returns:
        ``([array], [tensor])``: two single-element lists, one entry per
        network input.
    """
    total = np.prod(shape)
    array = np.random.rand(total).reshape(shape)
    tensor = Variable(torch.Tensor(array))
    tensor = tensor.cuda() if gpu else tensor
    return [array], [tensor]
def get_input_size(caffe_net):
    """Return the shape of the first input blob of ``caffe_net``."""
    first_input = caffe_net.inputs[0]
    blob = caffe_net.blobs[first_input]
    return blob.data.shape
def forward_torch(net, data):
    """Run a PyTorch model once and capture each submodule's output.

    Side effect: every submodule exposing an ``inplace`` attribute gets it set
    to ``False`` so captured outputs are not overwritten by a later in-place
    op. This change persists on ``net``.

    Args:
        net: the ``nn.Module`` to run.
        data: sequence of inputs, passed positionally to ``net.forward``.

    Returns:
        ``(blobs, outputs)``.
        ``blobs`` is an ``OrderedDict`` mapping a module name (dots replaced
        by underscores) to its tensor output as a NumPy array, ordered by
        first hook call.
        ``outputs`` is a list of NumPy arrays, one per value returned by the
        model (a list/tuple return yields several).
    """
    blobs = OrderedDict()
    module2name = {}
    for layer_name, m in net.named_modules():
        module2name[m] = layer_name.replace('.', '_')
        # Turn off in-place ops so each hooked output keeps its own values.
        if hasattr(m, 'inplace'):
            m.inplace = False

    def forward_hook(m, i, o):
        # Only plain tensor outputs are comparable layer by layer; modules
        # returning tuples/lists would otherwise crash on ``o.data``.
        if isinstance(o, torch.Tensor):
            blobs[module2name[m]] = o.detach().cpu().numpy()

    handles = [m.register_forward_hook(forward_hook) for m in net.modules()]
    try:
        # ``net.forward`` is called directly (not ``net(...)``), so the root
        # module's own hook does not fire; only submodule outputs are recorded.
        output = net.forward(*data)
    finally:
        # Detach the hooks so repeated calls don't accumulate duplicates.
        for handle in handles:
            handle.remove()

    if isinstance(output, (list, tuple)):
        outputs = [o.detach().cpu().numpy() for o in output]
    else:
        outputs = [output.detach().cpu().numpy()]
    return blobs, outputs
def forward_caffe(net, data):
    """Run a pycaffe network once and collect every blob grouped by layer.

    Args:
        net: the Caffe network.
        data: sequence of arrays, written in order into the blobs named by
            ``net.inputs``.

    Returns:
        ``(blobs, outputs)``.
        ``blobs`` is an ``OrderedDict`` mapping a layer name to a list of
        copies of the blobs it writes, in ``net.blobs`` order.
        ``outputs`` lists ``net.forward()`` results in ``net.outputs`` order.
    """
    for input_name, d in zip(net.inputs, data):
        net.blobs[input_name].data[...] = d
    rst = net.forward()

    # Map each top blob to the layer that writes it. When several layers
    # write the same blob (in-place layers), the last one listed wins.
    blob2layer = {}
    for layer_name, tops in net.top_names.items():
        for top in tops:
            blob2layer[top] = layer_name

    blobs = OrderedDict()
    for blob_name, blob in net.blobs.items():
        # Blobs with no producing layer (e.g. plain net inputs) are filed
        # under their own name instead of raising KeyError.
        layer_name = blob2layer.get(blob_name, blob_name)
        # ``.data`` is a view onto the net's memory (the input assignment
        # above relies on this), so copy it to keep later passes from
        # overwriting the captured values.
        blobs.setdefault(layer_name, []).append(blob.data.copy())

    outputs = [rst[output_name] for output_name in net.outputs]
    return blobs, outputs
def test(net_caffe, net_torch, data_np, data_torch, decimal=2, check_outputs=False):
    """Compare a Caffe network against a PyTorch model, layer by layer.

    Runs ``forward_caffe`` and ``forward_torch`` once each. For every Caffe
    layer name that also appears among the PyTorch module names, the layer's
    first collected blob is compared with the module output, and
    ``TEST layer <name>: PASS`` or ``FAIL`` is printed. Layers whose element
    counts differ, or whose name contains ``'relu'``, are skipped.

    Args:
        net_caffe: network passed to ``forward_caffe``.
        net_torch: model passed to ``forward_torch``.
        data_np: inputs for the Caffe net (one array per net input).
        data_torch: inputs for the PyTorch model (positional forward args).
        decimal: precision for ``np.testing.assert_almost_equal``.
        check_outputs: when true, also compare the final network outputs
            pairwise.

    Raises:
        AssertionError: if a failing layer still differs after both sides are
            clipped to non-negative values, or (with ``check_outputs``) if an
            output pair differs.
    """
    blobs_caffe, rsts_caffe = forward_caffe(net_caffe, data_np)
    blobs_torch, rsts_torch = forward_torch(net_torch, data_torch)

    # Compare the output of every layer both sides know about.
    for layer, values in blobs_caffe.items():
        if layer not in blobs_torch:
            continue
        value = values[0]
        value_torch = blobs_torch[layer]
        if value.size != value_torch.size:
            continue
        if 'relu' in layer:
            continue
        try:
            np.testing.assert_almost_equal(value, value_torch, decimal=decimal)
        except AssertionError:
            print("TEST layer {}: FAIL".format(layer))
            # Retry on the non-negative parts, presumably to tolerate a ReLU
            # Caffe applied in place on this blob (TODO confirm). Use the same
            # tolerance as the primary check; the old default (7 decimals) was
            # stricter than the check it backs up. np.clip(x, 0, None) also
            # works on NumPy < 2.1, unlike the ``min=`` keyword.
            np.testing.assert_almost_equal(
                np.clip(value, 0, None), np.clip(value_torch, 0, None), decimal=decimal
            )
        else:
            print("TEST layer {}: PASS".format(layer))

    if check_outputs:
        print("TEST output")
        for rst_caffe, rst_torch in zip(rsts_caffe, rsts_torch):
            np.testing.assert_almost_equal(rst_caffe, rst_torch, decimal=decimal)
        print("TEST output: PASS")
class CNN(nn.Module):
    """Small demo network: 3x3 conv (3 -> 16 channels, stride 1, padding 1),
    BatchNorm, ReLU, then flatten to ``(batch, -1)``.

    The layers live in ``self.conv1`` (an ``nn.Sequential``), so parameter
    names are ``conv1.0.*`` / ``conv1.1.*``. Checkpoints saved from this class
    and loaded back into it rely on that layout.
    """

    def __init__(self):
        super(CNN, self).__init__()
        conv = nn.Conv2d(
            in_channels=3,
            out_channels=16,
            kernel_size=3,
            stride=1,
            padding=1,
        )
        norm = nn.BatchNorm2d(16)
        act = nn.ReLU()
        self.conv1 = nn.Sequential(conv, norm, act)

    def forward(self, x):
        features = self.conv1(x)
        batch = features.size(0)
        # Kept as Tensor.view; presumably the pytorch_to_caffe tracer handles
        # ``view`` specifically (TODO confirm before switching to flatten).
        return features.view(batch, -1)
def trans(name="demo", input_shape=(1, 3, 608, 608)):
    """Build a fresh ``CNN``, checkpoint it and convert it to Caffe.

    Writes ``model.pt`` (``{'model': state_dict}``), then uses
    ``pytorch_to_caffe`` to trace the model on an all-ones input of
    ``input_shape`` and write ``<name>.prototxt`` / ``<name>.caffemodel``.

    Args:
        name: base file name for the Caffe outputs.
        input_shape: shape of the dummy tracing input.
    """
    model = CNN()
    model.eval()
    torch.save({'model': model.state_dict()}, "model.pt")

    # Imported here (after the checkpoint is written, as before) so that only
    # the conversion path needs the converter module.
    import pytorch_to_caffe

    dummy = Variable(torch.ones(tuple(input_shape)))
    pytorch_to_caffe.trans_net(model, dummy, name)
    pytorch_to_caffe.save_prototxt('{}.prototxt'.format(name))
    pytorch_to_caffe.save_caffemodel('{}.caffemodel'.format(name))
    print("trans success")
def verfiy(name="demo"):
    """Check the converted Caffe model against the saved PyTorch checkpoint.

    Loads ``model.pt`` into a fresh ``CNN`` and loads ``<name>.prototxt`` /
    ``<name>.caffemodel`` in ``caffe.TEST`` phase. It then feeds both the same
    random input, shaped like the Caffe net's first input, and prints per-layer
    results via ``test``. (The function name is kept as spelled for existing
    callers.)

    Args:
        name: base file name of the Caffe files to load.
    """
    model = CNN()
    model.load_state_dict(torch.load("model.pt", map_location='cpu')['model'])
    # Evaluate in inference mode: the model was exported after eval(), and
    # Caffe runs in TEST phase. In train() mode BatchNorm would normalise with
    # per-batch statistics (and overwrite its running stats), so the compared
    # layers would not match.
    model.eval()
    import caffe
    net_caffe = caffe.Net('{}.prototxt'.format(name), '{}.caffemodel'.format(name), caffe.TEST)
    shape = get_input_size(net_caffe)
    print(shape)
    data_np, data_torch = generate_random(shape)
    test(net_caffe, model, data_np, data_torch)
if __name__ == '__main__':
    # Exactly one command-line argument (its value is ignored) runs the
    # PyTorch -> Caffe conversion; any other argument count runs verification.
    run = trans if len(sys.argv) == 2 else verfiy
    run()