xuebinqin / U-2-Net

The code for our newly accepted paper in Pattern Recognition 2020: "U^2-Net: Going Deeper with Nested U-Structure for Salient Object Detection."
Apache License 2.0
8.31k stars 1.43k forks source link

the effect of replacing background with U-2-Net is not as expected, anyone can offer a help #377

Open weiweiwang opened 5 months ago

weiweiwang commented 5 months ago

Env

model: u2net downloaded from the link in the repository README:https://pan.baidu.com/s/1WjwyEwDiaUjBbx_QxcXBwQ

Test input

image: https://cdn.avatar.dmc-ai.cn/avatar/photos/2024/02/02/gJT2UhwWmcHgM6ep.jpg background: https://cdn.avatar.dmc-ai.cn/avatar/photos/2024/02/02/DCAqTPHo7Advhmrv.jpg current result: https://cdn.avatar.dmc-ai.cn/avatar/photos/2024/02/02/RFD5jxgohUXX4fAX.jpg

Test Method

  1. I modified the u2net_test.py(as below) and place the image in the folder: test_images
  2. write replace background images to folder: test_data/rbg
def main():
    # --------- 1. get image path and name ---------
    model_name = 'u2net'  # u2netp
    # model_name = 'u2netp'  # u2netp

    image_dir = os.path.join(os.getcwd(), 'test_data', 'test_images')
    prediction_dir = os.path.join(os.getcwd(), 'test_data', model_name + '_results' + os.sep)
    model_dir = os.path.join(os.getcwd(), 'saved_models', model_name, model_name + '.pth')

    img_name_list = glob.glob(image_dir + os.sep + '*')
    print(img_name_list)

    # --------- 2. dataloader ---------
    # 1. dataloader
    test_salobj_dataset = SalObjDataset(img_name_list=img_name_list,
                                        lbl_name_list=[],
                                        transform=transforms.Compose([RescaleT(320),
                                                                      ToTensorLab(flag=0)])
                                        )
    test_salobj_dataloader = DataLoader(test_salobj_dataset,
                                        batch_size=1,
                                        shuffle=False,
                                        num_workers=1)

    # --------- 3. model define ---------
    if (model_name == 'u2net'):
        print("...load U2NET---173.6 MB")
        net = U2NET(3, 1)
    elif (model_name == 'u2netp'):
        print("...load U2NEP---4.7 MB")
        net = U2NETP(3, 1)

    if torch.cuda.is_available():
        net.load_state_dict(torch.load(model_dir))
        net.cuda()
    else:
        net.load_state_dict(torch.load(model_dir, map_location='cpu'))
    net.eval()

    # --------- 4. inference for each image ---------
    for i_test, data_test in enumerate(test_salobj_dataloader):

        print("inferencing:", img_name_list[i_test].split(os.sep)[-1])

        inputs_test = data_test['image']
        inputs_test = inputs_test.type(torch.FloatTensor)

        if torch.cuda.is_available():
            inputs_test = Variable(inputs_test.cuda())
        else:
            inputs_test = Variable(inputs_test)

        d1, d2, d3, d4, d5, d6, d7 = net(inputs_test)

        # normalization
        pred = d1[:, 0, :, :]
        pred = normPRED(pred)

        # save results to test_results folder
        if not os.path.exists(prediction_dir):
            os.makedirs(prediction_dir, exist_ok=True)

        output_filename = save_output(img_name_list[i_test], pred, prediction_dir)

        ########## modification comes here ##########
        predict = pred
        predict = predict.squeeze()
        predict_np = predict.cpu().data.numpy()
        input_image_file_path = img_name_list[i_test]
        image = cv2.imread(input_image_file_path)
        background = cv2.imread("test_data/bg-05.jpg")
        background = cv2.resize(background, (image.shape[1], image.shape[0]))
        im = Image.fromarray(predict_np * 255).convert('RGB')
        imo = im.resize((image.shape[1], image.shape[0]), resample=Image.BILINEAR)
        data = np.asarray(imo, dtype="int32")
        condition = data > 0.98 * 255
        output_image = np.where(condition, image, background)
        cv2.imwrite(f"test_data/rbg/{os.path.basename(input_image_file_path)}", output_image)

        del d1, d2, d3, d4, d5, d6, d7

I'm a newbie at this field, could anyone offer a help, thanks a lot~