vitoralbiero / img2pose

The official PyTorch implementation of img2pose: Face Alignment and Detection via 6DoF, Face Pose Estimation - CVPR 2021
Other
588 stars 109 forks source link

ONNX output is giving incorrect DOF values #77

Open AbrarAdnan opened 1 year ago

AbrarAdnan commented 1 year ago

Hi Vítor, I tried to convert this model to ONNX. When I ran inference with the ONNX model, the bounding boxes and detection accuracy were very similar to the original model's and within an acceptable range, but when I ran the align function the output was very distorted. I found that the DOF (6DoF pose) output was wrong. Here's an example — ONNX model DOF: [0.371725 -0.275935 1.866960 0.021321 -0.280082 0.334366]; main model DOF: [0.083832 -0.089648 0.311040 0.715069 -0.212262 10.305621]

I'm not sure what could cause this problem. I modified the run_face_align.py file to convert the model instance to ONNX and also run inference in the same file. I'm pasting the contents of the file here for additional context. You can call the align function with different parameters to run different modes. I tried replacing the alignment with something else in the "aligning_faces_onnx_new" section, but that one only worked by luck; I'm interested in finding the problem and fixing the ONNX output. Also, here's my conda env info in short: torch 1.7.1, python 3.9.16, onnx 1.14.0.

Let me know if you have any further questions. Thank you

modified_run_face_align.py

def image_preprocess(img_path):
    """Load an image as input for the ONNX session.

    Reads ``img_path`` with OpenCV, converts BGR to RGB, moves channels first
    (CHW) and scales pixel values to [0, 1] as float32, mirroring the scaling
    that ``transforms.ToTensor()`` applies for the PyTorch model.

    Returns:
        A one-element list holding the (3, H, W) float32 array. Presumably
        onnxruntime converts the list to a (1, 3, H, W) batch — confirm.

    Raises:
        FileNotFoundError: if OpenCV cannot read the image.
    """
    cv_image = cv2.imread(img_path)
    if cv_image is None:
        # cv2.imread signals failure by returning None instead of raising.
        raise FileNotFoundError(f"Could not read image: {img_path}")

    image_rgb = cv2.cvtColor(cv_image, cv2.COLOR_BGR2RGB)
    chw = image_rgb.transpose(2, 0, 1)  # HWC -> CHW
    return [np.float32(chw / 255.0)]

def image_preprocess_dof(img_path):
    """Load ``img_path`` with OpenCV and return it as an RGB PIL image.

    The image is read as uint8 BGR, converted to RGB and wrapped directly in a
    PIL image. No channel transposing or rescaling is needed, because
    ``Image.fromarray`` expects an HWC uint8 array.

    Raises:
        FileNotFoundError: if OpenCV cannot read the image.
    """
    cv_image = cv2.imread(img_path)
    if cv_image is None:
        # cv2.imread signals failure by returning None instead of raising.
        raise FileNotFoundError(f"Could not read image: {img_path}")

    image_rgb = cv2.cvtColor(cv_image, cv2.COLOR_BGR2RGB)
    return Image.fromarray(image_rgb)

def render_plot(img, poses, bboxes, output_file):
    """Draw face bounding boxes over ``img`` and save the figure.

    Args:
        img: PIL image to draw on.
        poses: per-face pose vectors. Pose rendering is not implemented here,
            so this argument is accepted but currently unused.
        bboxes: iterable of boxes indexed as (x1, y1, x2, y2).
        output_file: path of the saved figure; its parent directory is
            created if it does not exist.
    """
    out_dir = os.path.dirname(output_file)
    if out_dir:
        os.makedirs(out_dir, exist_ok=True)

    fig, ax = plt.subplots(figsize=(8, 8))
    for bbox in bboxes:
        x1, y1, x2, y2 = bbox[0], bbox[1], bbox[2], bbox[3]
        ax.add_patch(
            patches.Rectangle(
                (x1, y1),
                x2 - x1,
                y2 - y1,
                linewidth=3,
                edgecolor="b",
                facecolor="none",
            )
        )

    ax.imshow(img)
    fig.savefig(output_file)
    plt.show()
    # pyplot keeps every figure alive until it is closed; without this, calling
    # render_plot once per image accumulates figures.
    plt.close(fig)
# End of new functions

class img2pose:
    """Detect faces with img2pose and save aligned face crops.

    Holds the PyTorch ``img2poseModel`` plus the reference data it needs. It can
    also export the model's ``fpn_model`` to ONNX and run that export side by
    side with the PyTorch model for comparison.
    """

    def __init__(self, args):
        """Load reference points, pose statistics, the model and the image list.

        Args:
            args: namespace produced by ``parse_args()``.
        """
        self.threed_5_points = np.load(args.threed_5_points)
        self.threed_68_points = np.load(args.threed_68_points)
        self.nms_threshold = args.nms_threshold

        # Must be set before create_model(), which reads them.
        self.pose_mean = np.load(args.pose_mean)
        self.pose_stddev = np.load(args.pose_stddev)
        self.model = self.create_model(args)

        self.transform = transforms.Compose([transforms.ToTensor()])
        self.min_size = (args.min_size,)
        self.max_size = args.max_size

        self.max_faces = args.max_faces
        self.face_size = args.face_size
        self.order_method = args.order_method
        self.det_threshold = args.det_threshold

        images_path = args.images_path
        if os.path.isfile(images_path):
            # A file is read as a space-delimited list of image paths.
            listed = pd.read_csv(images_path, delimiter=" ", header=None)
            # squeeze() turns a one-row list into a 0-d array, which cannot be
            # iterated; atleast_1d keeps it a 1-d sequence.
            self.image_list = np.atleast_1d(np.asarray(listed).squeeze())
        else:
            # Sorted so images are processed in a deterministic order.
            self.image_list = [
                os.path.join(images_path, img_path)
                for img_path in sorted(os.listdir(images_path))
            ]

        self.output_path = args.output_path

    def create_model(self, args):
        """Build ``img2poseModel``, load pretrained weights and call ``evaluate()``.

        Weights are loaded into ``fpn_model``. CPU mode is used when the model's
        device is the CPU.
        """
        img2pose_model = img2poseModel(
            args.depth,
            args.min_size,
            args.max_size,
            pose_mean=self.pose_mean,
            pose_stddev=self.pose_stddev,
            threed_68_points=self.threed_68_points,
        )
        load_model(
            img2pose_model.fpn_model,
            args.pretrained_path,
            cpu_mode=str(img2pose_model.device) == "cpu",
            model_only=True,
        )
        img2pose_model.evaluate()

        return img2pose_model

    def align(
        self,
        export_onnx,
        use_onnx_bbox,
        use_main_bbox,
        aligning_faces_main,
        aligning_faces_onnx_new,
        aligning_faces_onnx_old,
    ):
        """Run the requested detection / alignment modes over ``self.image_list``.

        Args:
            export_onnx: export ``fpn_model`` to ONNX and return immediately,
                without processing any image.
            use_onnx_bbox: plot ONNX-model boxes to ``results/onnx/``.
            use_main_bbox: plot PyTorch-model boxes to ``results/main/``.
            aligning_faces_main: save faces aligned with PyTorch-model poses.
            aligning_faces_onnx_new: save crops rotated by an ad-hoc angle
                derived from the ONNX dofs.
            aligning_faces_onnx_old: save faces aligned with ONNX-model poses.
        """
        onnx_model_path = "models/img2pose_cpu.onnx"

        if export_onnx:
            # Export before opening any session: the ONNX file may not exist yet.
            self._export_onnx(onnx_model_path)
            return

        need_onnx = use_onnx_bbox or aligning_faces_onnx_new or aligning_faces_onnx_old
        need_main = use_main_bbox or aligning_faces_main
        # Open the ONNX session once, and only if an ONNX mode is requested.
        sess = onnxruntime.InferenceSession(onnx_model_path) if need_onnx else None

        for img_path in tqdm(self.image_list):
            image_name = os.path.split(img_path)[-1]
            img = Image.open(img_path).convert("RGB")

            # ONNX outputs are indexed as [boxes, labels, scores, dofs].
            onnx_output = (
                sess.run(None, {"input": image_preprocess(img_path)})
                if need_onnx
                else None
            )
            res = self.model.predict([self.transform(img)])[0] if need_main else None

            if use_onnx_bbox:
                print("using onnx bbox")
                print(img_path)
                keep = onnx_output[2] > self.det_threshold
                # Each kept detection contributes its own pose row.
                poses = [pose.squeeze() for pose in onnx_output[3][keep]]
                bboxes = list(onnx_output[0][keep])
                render_plot(
                    img.copy(),
                    poses,
                    bboxes,
                    output_file="results/onnx/" + "img2pose_v1" + image_name,
                )
            else:
                print("Not using onnx bbox")

            if use_main_bbox:
                print("using main bbox")
                keep = res["scores"].cpu().numpy() > self.det_threshold
                all_bboxes = res["boxes"].cpu().numpy().astype("float")
                all_dofs = res["dofs"].cpu().numpy().astype("float")
                render_plot(
                    img.copy(),
                    [pose.squeeze() for pose in all_dofs[keep]],
                    list(all_bboxes[keep]),
                    "results/main/" + "img2pose_v1" + image_name,
                )
            else:
                print("Not using main bbox")

            if aligning_faces_onnx_new:
                keep = np.where(onnx_output[2] >= self.det_threshold)
                filtered_output = {
                    "boxes": onnx_output[0][keep],
                    "labels": onnx_output[1][keep],
                    "scores": onnx_output[2][keep],
                    "dofs": onnx_output[3][keep],
                }
                print("filtered output")
                print(filtered_output)

                name, ext = os.path.splitext(image_name)
                for i, (bbox, pose_info) in enumerate(
                    zip(filtered_output["boxes"], filtered_output["dofs"])
                ):
                    cropped_img = img.crop(bbox)
                    # HACK: rotates the crop by 10x dof[2]. This is an ad-hoc
                    # heuristic, not a principled alignment.
                    angle = float(pose_info[2] * 10)
                    rotated_img = TF.rotate(cropped_img, angle)
                    save_name = f"onnx_align_new_{name}_{i}{ext}"
                    rotated_img.save(os.path.join(self.output_path, save_name))

            if aligning_faces_onnx_old:
                top_poses = self._select_top_poses(onnx_output[2], onnx_output[3])
                self._save_aligned_faces(
                    img, img_path, image_name, top_poses, "onnx_align_old", "onnx"
                )

            if aligning_faces_main:
                top_poses = self._select_top_poses(
                    res["scores"].cpu().numpy().astype("float"),
                    res["dofs"].cpu().numpy().astype("float"),
                )
                self._save_aligned_faces(
                    img, img_path, image_name, top_poses, "main_align_", "main"
                )

    def _export_onnx(self, onnx_model_path):
        """Trace ``self.model.fpn_model.module`` and write it to ``onnx_model_path``."""
        print("Exporting Onnx")
        # NOTE(review): tracing can freeze values derived from the dummy input's
        # concrete 800x1333 size. If the model's pose (dofs) post-processing
        # depends on image size, the exported dofs would be wrong for other
        # sizes. This is a plausible cause of the ONNX/PyTorch dof mismatch, but
        # it is unconfirmed here.
        dummy_input = torch.randn(1, 3, 800, 1333)
        output_names = ["boxes", "labels", "scores", "dofs"]
        dynamic_axes = {"input": {0: "batch", 2: "height", 3: "width"}}  # NCHW
        # The number of detections varies per image along axis 0 of each output.
        dynamic_axes.update({name: {0: "detections"} for name in output_names})

        torch.onnx.export(
            self.model.fpn_model.module,
            dummy_input,
            onnx_model_path,
            opset_version=12,
            export_params=True,
            verbose=True,
            input_names=["input"],
            output_names=output_names,
            dynamic_axes=dynamic_axes,
        )
        print("Exporting onnx done")

    def _select_top_poses(self, all_scores, all_poses):
        """Return up to ``self.max_faces`` poses scoring above ``self.det_threshold``.

        Poses are ordered by descending score ("confidence") or by ascending
        distance of (dof[3], dof[4]) from the origin ("position"). dof[3] and
        dof[4] are the Tx/Ty columns in the Rx..Tz order used by this script.

        Raises:
            ValueError: if ``self.order_method`` is not a known method.
        """
        keep = all_scores > self.det_threshold
        poses = all_poses[keep]
        scores = all_scores[keep]
        if len(poses) == 0:
            return poses

        if self.order_method == "confidence":
            order = np.argsort(scores)[::-1]
        elif self.order_method == "position":
            order = np.argsort(np.sqrt(poses[:, 3] ** 2 + poses[:, 4] ** 2))
        else:
            raise ValueError(f"Unknown order_method: {self.order_method!r}")

        return poses[order][: self.max_faces]

    def _save_aligned_faces(self, img, img_path, image_name, top_poses, prefix, label):
        """Align ``img`` to each pose and save the resized crops.

        Files go under ``self.output_path/<parent folder of img_path>``. The name
        is ``prefix + image_name``, or ``prefix + stem_i + ext`` when several
        faces are saved.
        """
        if len(top_poses) == 0:
            print(f"No face detected above the threshold {self.det_threshold}!")
            return

        sub_folder = os.path.basename(os.path.normpath(os.path.split(img_path)[0]))
        output_path = os.path.join(self.output_path, sub_folder)
        os.makedirs(output_path, exist_ok=True)

        name, ext = os.path.splitext(image_name)
        for i, pose in enumerate(top_poses):
            if len(top_poses) > 1:
                save_name = f"{prefix}{name}_{i}{ext}"
            else:
                save_name = f"{prefix}{image_name}"

            print(f"Top poses {label}[i]")
            print(pose)
            aligned_face = align_faces(self.threed_5_points, img, pose)[0]
            aligned_face = aligned_face.resize((self.face_size, self.face_size))
            aligned_face.save(os.path.join(output_path, save_name))

def parse_args():
    """Parse command-line options for face detection and alignment.

    Also creates ``--output_path`` if it does not exist yet.

    Returns:
        argparse.Namespace holding the parsed options.
    """
    parser = argparse.ArgumentParser(
        description="Align top n faces ordering by score or distance to image center."
    )
    parser.add_argument("--max_faces", help="Top n faces to save.", default=10, type=int)
    parser.add_argument(
        "--order_method",
        help="How to order faces [confidence, position].",
        default="confidence",  # alternative: "position"
        # Reject unknown values up front; the alignment code only
        # implements these two orderings.
        choices=["confidence", "position"],
        type=str,
    )
    parser.add_argument(
        "--face_size",
        help="Image size to save aligned faces [112 or 224].",
        default=224,
        type=int,
    )
    parser.add_argument("--min_size", help="Image min size", default=300, type=int)  # 400 #300
    parser.add_argument("--max_size", help="Image max size", default=1440, type=int)  # 1400 # 440
    parser.add_argument(
        "--depth", help="Number of layers [18, 50 or 101].", default=18, type=int
    )
    parser.add_argument(
        "--pose_mean",
        help="Pose mean file path.",
        type=str,
        default="./models/WIDER_train_pose_mean_v1.npy",
    )
    parser.add_argument(
        "--pose_stddev",
        help="Pose stddev file path.",
        type=str,
        default="./models/WIDER_train_pose_stddev_v1.npy",
    )
    parser.add_argument(
        "--pretrained_path",
        help="Path to pretrained weights.",
        type=str,
        default="./models/img2pose_v1.pth",
        # alternative: "./models/img2pose_v1_ft_300w_lp.pth"
    )
    parser.add_argument(
        "--threed_5_points",
        type=str,
        help="Reference 3D points to align the face.",
        default="./pose_references/reference_3d_5_points_trans.npy",
    )
    parser.add_argument(
        "--threed_68_points",
        type=str,
        help="Reference 3D points to project bbox.",
        default="./pose_references/reference_3d_68_points_trans.npy",
    )
    parser.add_argument("--nms_threshold", default=0.6, type=float)
    parser.add_argument(
        "--det_threshold", help="Detection threshold.", default=0.9, type=float
    )
    parser.add_argument("--images_path", help="Image list, or folder.", default="images")
    parser.add_argument(
        "--output_path", help="Path to save predictions", default="results/main/"
    )

    args = parser.parse_args()

    # exist_ok avoids a race between an existence check and the creation.
    os.makedirs(args.output_path, exist_ok=True)

    return args

if __name__ == "__main__":
    args = parse_args()

    # Keep the instance under its own name so the ``img2pose`` class itself is
    # not shadowed by the object built from it.
    aligner = img2pose(args)
    print('loaded')

    aligner.align(
        export_onnx=False,
        use_onnx_bbox=False,
        use_main_bbox=False,
        aligning_faces_main=True,
        aligning_faces_onnx_new=False,
        aligning_faces_onnx_old=False,
    )
121649982 commented 5 months ago

Have you solved it? I encountered this issue too.