gmalivenko / pytorch2keras

PyTorch to Keras model converter
https://pytorch2keras.readthedocs.io/en/latest/
MIT License

Converting a PyTorch model: AttributeError: Unsupported number of inputs #106

Open lipeng1109 opened 4 years ago

lipeng1109 commented 4 years ago

Describe the bug
Converting the PyTorch version of the CRAFT model to Keras fails with:

WARNING:onnx2keras:upsample:!!! EXPERIMENTAL SUPPORT (upsample) !!!
Traceback (most recent call last):
  File "test.py", line 190, in <module>
    k_model = converter.pytorch_to_keras(net, input_var, [(3, 224, 224,)], verbose=True,name_policy='short')
  File "/root/anaconda3/lib/python3.6/site-packages/pytorch2keras/converter.py", line 73, in pytorch_to_keras
    verbose=verbose, change_ordering=change_ordering)
  File "/root/anaconda3/lib/python3.6/site-packages/onnx2keras/converter.py", line 177, in onnx_to_keras
    keras_names
  File "/root/anaconda3/lib/python3.6/site-packages/onnx2keras/upsampling_layers.py", line 21, in convert_upsample
    raise AttributeError('Unsupported number of inputs')
AttributeError: Unsupported number of inputs
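If I read onnx2keras/upsampling_layers.py correctly, convert_upsample raises this error when an Upsample node carries more than one input and the extra (scales) input cannot be resolved to a constant. CRAFT's decoder calls F.interpolate with a size taken from another feature map, so the exported ONNX graph presumably has to compute the upsampling scales at runtime instead of storing them as constants. A quick way to check (sketch below, not part of the original script; 'craft.onnx' is just a scratch filename) is to export the same model to ONNX and list the upsampling nodes and their inputs:

import onnx
import torch

# net and input_var are the same objects built in the reproduction script below
# (if args.cuda wrapped net in DataParallel, export net.module instead)
torch.onnx.export(net, input_var, 'craft.onnx')

onnx_model = onnx.load('craft.onnx')
for node in onnx_model.graph.node:
    # Upsample (opset <= 9) or Resize (opset >= 10); extra inputs listed here
    # would be the scales that convert_upsample fails to resolve
    if node.op_type in ('Upsample', 'Resize'):
        print(node.op_type, 'inputs:', list(node.input))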

To Reproduce

net = CRAFT() # initialize

print('Loading weights from checkpoint (' + args.trained_model + ')')
if args.cuda:
    net.load_state_dict(copyStateDict(torch.load(args.trained_model)))
else:
    net.load_state_dict(copyStateDict(torch.load(args.trained_model, map_location='cpu')))

if args.cuda:
    net = net.cuda()
    net = torch.nn.DataParallel(net)
    cudnn.benchmark = False

net.eval()

# LinkRefiner
refine_net = None
if args.refine:
    from refinenet import RefineNet

    refine_net = RefineNet()
    print('Loading weights of refiner from checkpoint (' + args.refiner_model + ')')
    if args.cuda:
        refine_net.load_state_dict(copyStateDict(torch.load(args.refiner_model)))
        refine_net = refine_net.cuda()
        refine_net = torch.nn.DataParallel(refine_net)
    else:
        refine_net.load_state_dict(copyStateDict(torch.load(args.refiner_model, map_location='cpu')))

    refine_net.eval()
    args.poly = True

t = time.time()

# load data
for k, image_path in enumerate(image_list):
    print("Test image {:d}/{:d}: {:s}".format(k + 1, len(image_list), image_path), end='\r')
    image = imgproc.loadImage(image_path)

    print("模型转换开始")
    ## 处理模型转换
    from pytorch2keras import converter
    # resize
    img_resized, target_ratio, size_heatmap = imgproc.resize_aspect_ratio(image, args.canvas_size,
                                                                          interpolation=cv2.INTER_LINEAR,
                                                                          mag_ratio=args.mag_ratio)
    ratio_h = ratio_w = 1 / target_ratio
    # preprocessing
    x = imgproc.normalizeMeanVariance(img_resized)
    x = torch.from_numpy(x).permute(2, 0, 1)  # [h, w, c] to [c, h, w]
    print(x.shape)
    input_var = Variable(x.unsqueeze(0))  # [c, h, w] to [b, c, h, w]
    print(input_var.shape)
    # net(input_var)
    # input_np = np.random.uniform(0, 1, (1, 3, 224, 224))
    # input_var = Variable(torch.FloatTensor(input_np))
    # print(type(input_var))
    # print(input_var)
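    # note: input_var keeps the canvas-resized spatial size from resize_aspect_ratio,
    # which may not match the (3, 224, 224) input shape declared in the call below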
    k_model = converter.pytorch_to_keras(net, input_var, [(3, 224, 224,)], verbose=True)
    k_model.summary()
    # save the model
    k_model.save('my_model.h5')
    print("模型转换结束")
    bboxes, polys, score_text = ctest_net(net, image, args.text_threshold, args.link_threshold, args.low_text,
                                          args.cuda, args.poly, refine_net)

    # save score text
    filename, file_ext = os.path.splitext(os.path.basename(image_path))
    mask_file = result_folder + "/res_" + filename + '_mask.jpg'
    cv2.imwrite(mask_file, score_text)
    image = image.copy()
    file_utils.saveResult(image_path, image[:, :, ::-1], polys, dirname=result_folder)

print("elapsed time : {}s".format(time.time() - t))

Expected behavior
The CRAFT model is converted to a Keras model and saved to my_model.h5 without errors.

Logs
The full traceback is included above.

Environment (please complete the following information):

- Python 3.6 (Anaconda), per the site-packages paths in the traceback
- pytorch2keras and onnx2keras versions: not recorded
