fanq15 / FewX

FewX is an open-source toolbox on top of Detectron2 for data-limited instance-level recognition tasks.
https://github.com/fanq15/FewX
MIT License

Low performance / wrong boxes fix #71

Open andrearosasco opened 2 years ago

andrearosasco commented 2 years ago

I don't know how many others have run into the same issue, but it took me a while to get correct predictions from the model. Anyway, I'm leaving this here hoping it can help someone.

So, I wanted to write some code to add my own support set (not from COCO or PASCAL) and test the performance on some query images. For some reason, the final bounding boxes were entirely misplaced with respect to the instances. Later I found out that each correct point was lying on the line parallel to the secondary diagonal passing through the predicted one, mirrored with respect to the main diagonal. Now, I am fairly sure I caused the problem myself by not processing the input correctly, but I patched it geometrically and now it works.
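If I did the algebra right, the coordinate swap plus the mirror in correct_box_coordinates below boils down to scaling every x by width/height and every y by height/width on the query image. A tiny self-contained check of that reading (the numbers are made up):

W, H = 800, 600          # hypothetical query size
x, y = 120.0, 300.0      # one corner as predicted by the model

sx, sy = y, x                            # step 1: swap x and y
mx, my = (W / H) * sy, (H / W) * sx      # step 2: my reading of correct_box_coordinates: (x, y) -> (W/H * y, H/W * x)

assert (mx, my) == ((W / H) * x, (H / W) * y)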

Here's the full code:

import cv2
import torch
from detectron2.checkpoint import DetectionCheckpointer
from detectron2.engine import default_setup, default_argument_parser, launch
from detectron2.utils import comm
from detectron2.utils.logger import setup_logger

from fewx.config import get_cfg
from fsod_train_net import Trainer
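# note: the fewx and fsod_train_net imports assume this script is run from the FewX repo root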

def setup(args):
    """
    Create configs and perform basic setups.
    """
    cfg = get_cfg()
    cfg.merge_from_file(args.config_file)
    cfg.merge_from_list(args.opts)
    cfg.freeze()
    default_setup(cfg, args)

    rank = comm.get_rank()
    setup_logger(cfg.OUTPUT_DIR, distributed_rank=rank, name="fewx")

    return cfg

def main(args):
    cfg = setup(args)

    model = Trainer.build_model(cfg)
    model.eval()

    # load the trained weights pointed to by cfg.MODEL.WEIGHTS
    DetectionCheckpointer(model, save_dir=cfg.OUTPUT_DIR).resume_or_load(
        cfg.MODEL.WEIGHTS, resume=args.resume
    )
    # dataloader = Trainer.build_test_loader(cfg, "coco_2017_val")
    # read the query image and resize it to a fixed height of 600 px, keeping the aspect ratio
    file_name = 'assets/query.jpg'
    query = cv2.imread(file_name)
    height, width = query.shape[:2]
    query = cv2.resize(query, (int(600 * (width / height)), 600))
    # detectron2 expects a C x H x W tensor in the model's input format (BGR by default, as loaded by OpenCV)
    query = torch.tensor(query).permute(2, 0, 1)

    # build a detectron2-style input dict for the resized query
    in_data = [{
        'file_name': file_name,
        'height': query.shape[2],
        'width': query.shape[1],
        'image_id': 0,
        'image': query,
    }]

    # run the detector on the query image and take its predicted boxes
    out = model(in_data)
    boxes = out[0]['instances'].pred_boxes.tensor

    # keep only the confident detections, then swap the x and y coordinate of every corner
    scores = out[0]['instances'].scores
    boxes = boxes[scores > 0.8]
    boxes[:, [0, 2]], boxes[:, [1, 3]] = boxes[:, [1, 3]], boxes[:, [0, 2]]

    # mirror every corner back across the image's main diagonal (see correct_box_coordinates below)
    boxes = correct_box_coordinates(query.shape[1], query.shape[2], boxes.reshape(-1, 2).tolist())
    boxes = torch.tensor(boxes).reshape(-1, 4)

    # draw the corrected boxes on the query image and display it
    image = in_data[0]['image'].permute(1, 2, 0).numpy()

    for b in boxes:
        image = cv2.rectangle(image.copy(), tuple(b[[0, 1]].to(int).tolist()), tuple(b[[2, 3]].to(int).tolist()), color=(255, 0, 0), thickness=5)

    cv2.imshow('out', image)
    cv2.waitKey(0)

def correct_box_coordinates(height, width, points):
    """Mirror each (x, y) point of a `height` x `width` image across its main diagonal.

    Every point is moved along the direction of the secondary diagonal to the
    symmetric position on the other side of the main diagonal. `points` is a list
    of [x, y] pairs; the result is returned as a flat list [x0, y0, x1, y1, ...].
    """
    res = []

    # work in standard (y pointing up) coordinates
    width = -width
    for x_0, y_0 in points:
        y_0 = -y_0

        # y-intercept of the line through (x_0, y_0) parallel to the secondary diagonal
        q = (height / width) * x_0 + y_0

        # intersection of that line with the main diagonal
        x_c = (width * q) / (2 * height)
        y_c = q / 2

        # mirror the point through the intersection
        x_p = 2 * x_c - x_0
        y_p = 2 * y_c - y_0

        # flip y back to image coordinates
        res += [x_p, -y_p]

    return res

if __name__ == "__main__":
    args = default_argument_parser().parse_args()
    print("Command Line Args:", args)
    launch(
        main,
        args.num_gpus,
        num_machines=args.num_machines,
        machine_rank=args.machine_rank,
        dist_url=args.dist_url,
        args=(args,),
    )
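For completeness: I think the mistake I originally made is in in_data. After the permute the query tensor is C x H x W, so query.shape[1] is the height and query.shape[2] the width, but I passed them the other way around, and as far as I understand detectron2 rescales the predicted boxes to whatever 'height' and 'width' you declare in the input dict. If that is right, building the dict like this (untested sketch, everything else as above) should make the coordinate swap and correct_box_coordinates unnecessary:

in_data = [{
    'file_name': file_name,
    'height': query.shape[1],   # H of the C x H x W tensor
    'width': query.shape[2],    # W of the C x H x W tensor
    'image_id': 0,
    'image': query,
}]

out = model(in_data)
instances = out[0]['instances']
boxes = instances.pred_boxes.tensor[instances.scores > 0.8]  # (x1, y1, x2, y2) in query-image pixels

If someone tries this, it would be good to confirm that the boxes then land on the instances without the mirroring step.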

And this is the test image I was using: https://media.ktoo.org/2013/10/Brown-Bears.jpg