xxradon / PytorchToCaffe

Converts PyTorch models to Caffe models. Supports PyTorch 0.3, 0.3.1, 0.4, 0.4.1, 1.0, 1.0.1, 1.2 and 1.3. Note that only PyTorch 1.1 has some bugs.
MIT License

How can I add support for converting F.relu6? #44

Closed. Dawson-huang closed this 4 years ago.

Dawson-huang commented 5 years ago

Hi, I tried adding this code:

def _relu6(raw, input, inplace=False):
    # for F.relu6
    x = raw(input, inplace)
    bottom_blobs = [log.blobs(input)]
    name = log.add_layer(name='relu6')
    top_blobs = log.add_blobs([x], name=bottom_blobs[0], with_num=False)
    layer = caffe_net.Layer_param(name=name, type='ReLU6',
                                  bottom=bottom_blobs, top=top_blobs)
    log.cnet.add_layer(layer)
    return x

or this:

def _relu6(raw, input, inplace=False):
    # for F.relu6 (always runs out of place; the inplace flag is ignored)
    x = raw(input, False)
    name = log.add_layer(name='relu6')
    log.add_blobs([x], name='relu6_blob')
    layer = caffe_net.Layer_param(name=name, type='ReLU6',
                                  bottom=[log.blobs(input)], top=[log.blobs(x)])
    log.cnet.add_layer(layer)
    return x

and then added this line below it:

F.relu6=Rp(F.relu6,_relu6)
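The line above relies on the converter's `Rp` wrapper to swap `F.relu6` for a tracing function whose first argument is the original (`raw`) function. The sketch below is a hypothetical, plain-Python stand-in for that pattern, not the repository's actual `Rp` class: `TracingReplacement`, `trace_log` and `_relu6_trace` are illustrative names, and `relu6` here works on a single float instead of a tensor.

```python
from typing import Any, Callable, List

# Hypothetical record of traced layers (stands in for the converter's Caffe net).
trace_log: List[str] = []


class TracingReplacement:
    """Hypothetical stand-in for a function-replacement wrapper such as Rp.

    Calling the instance forwards to `replacement(raw, *args, **kwargs)`,
    so the replacement can run the original function and record a layer.
    """

    def __init__(self, raw: Callable[..., Any], replacement: Callable[..., Any]) -> None:
        self.raw = raw
        self.replacement = replacement

    def __call__(self, *args: Any, **kwargs: Any) -> Any:
        return self.replacement(self.raw, *args, **kwargs)


def relu6(x: float, inplace: bool = False) -> float:
    # Plain-Python stand-in for F.relu6 on one scalar: clip to [0, 6].
    return min(max(x, 0.0), 6.0)


def _relu6_trace(raw: Callable[..., Any], x: float, inplace: bool = False) -> float:
    # Run the original op out of place, then record a layer for it.
    out = raw(x, False)
    trace_log.append("ReLU6")
    return out


# Rebind the name, mirroring `F.relu6 = Rp(F.relu6, _relu6)`.
relu6 = TracingReplacement(relu6, _relu6_trace)

print(relu6(7.5))   # 6.0
print(trace_log)    # ['ReLU6']
```

Because the wrapper passes the original function in as `raw`, the tracing function can produce the real output and record a layer in one call, which is why both attempts in the issue take `raw` as their first parameter.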

Both versions still fail with this error:

139789595636720:add_blob1 was added to blobs
Add blob       add_blob1       : torch.Size([1, 16, 112, 112])
139789595637440:batch_norm_blob1 getting
Traceback (most recent call last):
  File "example/mobilenet_pytorch_to_caffe.py", line 20, in <module>
    pytorch_to_caffe.trans_net(net, input, name)
  File "/home/sonny/PytorchToCaffe/pytorch_to_caffe.py", line 658, in trans_net
    out = net.forward(input_var)
  File "mobilenet/mobilenet.py", line 182, in forward
    out = self.hs1(self.bn1(self.conv1(x)))
  File "/usr/local/lib/python3.6/dist-packages/torch/nn/modules/module.py", line 493, in __call__
    result = self.forward(*input, **kwargs)
  File "mobilenet/mobilenet.py", line 15, in forward
    out = x * F.relu6(x + 3, inplace=True) / 6
  File "/home/sonny/PytorchToCaffe/pytorch_to_caffe.py", line 486, in _add
    bottom=[log.blobs(input),log.blobs(args[0])], top=top_blobs)
  File "/home/sonny/PytorchToCaffe/pytorch_to_caffe.py", line 88, in blobs
    print("{}:{} getting".format(var, self._blobs[var]))
  File "/home/sonny/PytorchToCaffe/pytorch_to_caffe.py", line 31, in __getitem__
    return self.data[key]
KeyError: 10914560

Do you know how to fix this? Any help would be appreciated.
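Judging from the traceback, the exception is raised inside the converter's `_add` hook rather than in `_relu6`: the expression `x + 3` reaches `log.blobs(args[0])` with `args[0]` being the Python scalar `3`, which was never registered as a blob, so the lookup raises `KeyError`. The sketch below reproduces that failure mode with a hypothetical blob table keyed by `id()` and shows one way to skip scalar operands. `BlobRegistry` and `add_bottoms` are illustrative names, not the repository's API.

```python
from typing import Any, Dict, List


class BlobRegistry:
    """Hypothetical blob table that maps id(tensor) -> Caffe blob name."""

    def __init__(self) -> None:
        self._blobs: Dict[int, str] = {}

    def add(self, obj: Any, name: str) -> None:
        self._blobs[id(obj)] = name

    def blob(self, obj: Any) -> str:
        # Raises KeyError for anything never registered, e.g. the scalar 3.
        return self._blobs[id(obj)]


def add_bottoms(reg: BlobRegistry, lhs: Any, rhs: Any) -> List[str]:
    """Collect bottom blobs for `lhs + rhs`, skipping Python scalars.

    A scalar operand has no blob, so a converter would have to fold it into
    the layer's parameters (for example the `shift` of a Caffe Power layer)
    instead of treating it as a second bottom.
    """
    bottoms = [reg.blob(lhs)]
    if not isinstance(rhs, (int, float)):
        bottoms.append(reg.blob(rhs))
    return bottoms


x = object()  # stands in for a traced tensor
reg = BlobRegistry()
reg.add(x, "batch_norm_blob1")

try:
    reg.blob(3)  # the unguarded lookup for `x + 3` raises KeyError
except KeyError as err:
    print("KeyError:", err)

print(add_bottoms(reg, x, 3))  # ['batch_norm_blob1']
```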

xxradon commented 4 years ago

The latest code already includes a ReLU6 implementation.
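For reference, ReLU6 clips its input to [0, 6], and the failing line in the issue, `x * F.relu6(x + 3, inplace=True) / 6`, is the hard-swish (h-swish) activation built on top of it. The plain-Python sketch below spells out both formulas element-wise on lists; it uses neither PyTorch nor Caffe, and the function names are illustrative. Seen this way, converting that line means emitting an add-a-constant step, a clip, a multiply and a divide, so the scalar `x + 3` needs handling just as much as `relu6` itself.

```python
from typing import List


def relu6(values: List[float]) -> List[float]:
    # ReLU6(v) = min(max(v, 0), 6), applied element-wise.
    return [min(max(v, 0.0), 6.0) for v in values]


def hard_swish(values: List[float]) -> List[float]:
    # h-swish(v) = v * ReLU6(v + 3) / 6, the same expression as in mobilenet.py.
    shifted = relu6([v + 3.0 for v in values])
    return [v * r / 6.0 for v, r in zip(values, shifted)]


print(relu6([-2.0, 3.0, 9.0]))        # [0.0, 3.0, 6.0]
print(hard_swish([-1.5, 0.0, 3.0]))   # [-0.375, 0.0, 3.0]
```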

dreamhighchina commented 4 years ago

Hey, which version of MobileNet is this??