Open BobbyXue666 opened 1 year ago
麻烦提供一下MegEngine的版本,以及最小复现,就目前的信息很难定位问题。
版本:MegEngine 1.12.1+cu114,megfile 2.0.0.post2。最小复现如下:
class TestModel(M.Module):
    """Minimal three-stage convolutional network used as the reproduction case.

    Two stride-2 ``Conv2d -> BatchNorm2d -> ReLU`` stages are followed by a
    5x5 ``Conv2d`` and a ``Sigmoid``. Every convolution is created with
    ``bias=False``.

    NOTE(review): ``M`` is presumably ``megengine.module`` -- confirm against
    the import lines, which are not part of this snippet.
    """

    def __init__(self):
        super().__init__()
        # Stage 1: 2 -> 16 channels, 3x3 kernel, stride 2, no bias.
        self.conv1 = M.Conv2d(2, 16, kernel_size=3, stride=2, padding=1, bias=False)
        self.bn1 = M.BatchNorm2d(16)
        self.relu1 = M.ReLU()
        # Stage 2: 16 -> 4 channels, 3x3 kernel, stride 2, no bias.
        self.conv2 = M.Conv2d(16, 4, kernel_size=3, stride=2, padding=1, bias=False)
        self.bn2 = M.BatchNorm2d(4)
        self.relu2 = M.ReLU()
        # Stage 3: 4 -> 1 channel, 5x5 kernel without padding, no bias.
        self.conv3 = M.Conv2d(4, 1, kernel_size=5, stride=1, padding=0, bias=False)
        self.ac3 = M.Sigmoid()

    def forward(self, x):
        """Apply the three stages to ``x`` in order and return the sigmoid output."""
        out = self.relu1(self.bn1(self.conv1(x)))
        out = self.relu2(self.bn2(self.conv2(out)))
        out = self.ac3(self.conv3(out))
        return out
# Instantiate the network and call eval() before running it.
model = TestModel()
model.eval()

# Random float32 input of shape (1, 2, 20, 20).
input_shape = (1, 2, 20, 20)
random_input = np.random.random(input_shape).astype(np.float32)
data = tensor(random_input)

# Eager forward pass, printed before the model is traced and dumped.
prediction = model(data)
print(prediction)
@jit.trace(symbolic=True, capture_as_const=True)
def fun(data, *, net):
    """Forward ``data`` through ``net`` and return the prediction.

    The function is wrapped by ``jit.trace(symbolic=True,
    capture_as_const=True)``; ``net`` is keyword-only.
    """
    pred = net(data)
    return pred
# Call the traced function once with the sample input before dumping it.
fun(data, net=model)

# Write the traced function to "test.mge" with the inference optimization
# and both conv-bias fusion options explicitly switched off.
fun.dump(
    "test.mge",
    arg_names=["data"],
    optimize_for_inference=False,
    enable_fuse_conv_bias_nonlinearity=False,
    enable_fuse_conv_bias_with_z=False,
)
用你上面的代码和对应的MegEngine的版本,我这边测试没有问题的。
这是我的json文件
{
    "dump_dir": "./batch_dump/",
    "models": [
        {
            "model_name": "test",
            "model_path": "./example/test.mge",
            "input_shape_str": "data=(1,2,20,20)",
            "enable_nchw44": true
        }
    ],
    "cv": {
        "transpose": ["ui8"],
        "roicopy": ["ui8"],
        "rotate": ["ui8"],
        "flip": ["ui8"],
        "resize_linear": ["ui8"],
        "warp_affine_replicate_linear": ["ui8"],
        "rgb2bgr": ["ui8"],
        "yuv2bgr_nv21": ["ui8"],
        "rgb2yuv": ["ui8"]
    }
}
这是生成模型的命令
./script/ppl_gen.sh ./bin/mgb-to-tinynn ./example/demo.json test --arm64
将卷积的 `bias=False` 参数去掉(即使用默认的 bias=True)后可以编译通过,想知道是为什么?修改后的写法如下:
self.conv1 = M.Conv2d(2,16,kernel_size=3, stride=2, padding=1)
@BobbyXue666 把 `fun.dump` 里的 `optimize_for_inference=False` 改成 `True` 试试。
dump代码如下:
@jit.trace(symbolic=True, capture_as_const=True)
def fun(data, *, net):
    pred = net(data)
    return pred

fun(data, net=model)
fun.dump("test.mge", arg_names=["data"],optimize_for_inference=False, enable_fuse_conv_bias_nonlinearity=False, enable_fuse_conv_bias_with_z=False)