facebookresearch / Detic

Code release for "Detecting Twenty-thousand Classes using Image-level Supervision".
Apache License 2.0
1.86k stars 211 forks source link

Export Detic to ONNX with custom vocabulary #113

Open gigasurgeon opened 10 months ago

gigasurgeon commented 10 months ago

I wanted to share a method for exporting the Detic model to ONNX format with a custom vocabulary.

Step 1) First of all, comment out the line `box_features = _ScaleGradient.apply(box_features, 1.0 / self.num_cascade_stages)` in custom_rcnn.py.

Step 2) Also, as described in this comment https://github.com/facebookresearch/Detic/issues/107#issuecomment-1752039648 , you have to comment out the `nms_and_topK` line in CenterNet while exporting the model:

boxlists = self.nms_and_topK(boxlists, nms=not self.not_nms)

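Step 3) Now for the main part: modify the file `Detic/detectron2/tools/deploy/export_model.py`. (The command in Step 4 runs it as `detectron2/tools/deploy/export_model_lvis_vocabulary.py`, so either save your modified copy under that name or adjust the command.)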

This is the final script I ended up with:

#!/usr/bin/env python
# Copyright (c) Facebook, Inc. and its affiliates.
import argparse
import os
from typing import Dict, List, Tuple
import torch
from torch import Tensor, nn

import sys
sys.path.insert(0, '/vmdata/amitsingh/workspace/Detic')
sys.path.insert(0, '/vmdata/amitsingh/workspace/Detic/third_party/CenterNet2')

import detectron2.data.transforms as T
from detectron2.checkpoint import DetectionCheckpointer
from detectron2.config import get_cfg
from detectron2.data import build_detection_test_loader, detection_utils
from detectron2.evaluation import COCOEvaluator, inference_on_dataset, print_csv_format
from detectron2.export import (
    STABLE_ONNX_OPSET_VERSION,
    TracingAdapter,
    dump_torchscript_IR,
    scripting_with_instances,
)
from detectron2.modeling import GeneralizedRCNN, RetinaNet, build_model
from detectron2.modeling.postprocessing import detector_postprocess
from detectron2.projects.point_rend import add_pointrend_config
from detectron2.structures import Boxes
from detectron2.utils.env import TORCH_VERSION
from detectron2.utils.file_io import PathManager
from detectron2.utils.logger import setup_logger
from centernet.config import add_centernet_config
from detic.config import add_detic_config

def setup_cfg(args):
    """Build a frozen detectron2 config with PointRend, CenterNet and Detic keys registered.

    The extension keys are registered before any merge so that both the YAML file
    (``args.config_file``) and the command-line overrides (``args.opts``) may refer
    to them.
    """
    cfg = get_cfg()
    # The CUDA context is created before the dataloader, so worker processes
    # must not be forked.
    cfg.DATALOADER.NUM_WORKERS = 0
    for register_extension in (add_pointrend_config, add_centernet_config, add_detic_config):
        register_extension(cfg)
    cfg.merge_from_file(args.config_file)
    cfg.merge_from_list(args.opts)
    cfg.freeze()
    return cfg

def export_caffe2_tracing(cfg, torch_model, inputs):
    """Export ``torch_model`` with detectron2's ``Caffe2Tracer``.

    The output format and destination come from the module-level ``args``
    (``args.format`` / ``args.output``) parsed in ``__main__``.

    Returns:
        The Caffe2 model when ``args.format == "caffe2"``, otherwise ``None``.
    """
    from detectron2.export import Caffe2Tracer

    tracer = Caffe2Tracer(cfg, torch_model, inputs)
    out_dir = args.output
    fmt = args.format

    if fmt == "caffe2":
        c2_model = tracer.export_caffe2()
        c2_model.save_protobuf(out_dir)
        # Also render the Caffe2 graph for inspection.
        c2_model.save_graph(os.path.join(out_dir, "model.svg"), inputs=inputs)
        return c2_model

    if fmt == "onnx":
        import onnx

        onnx.save(tracer.export_onnx(), os.path.join(out_dir, "model.onnx"))
    elif fmt == "torchscript":
        scripted = tracer.export_torchscript()
        with PathManager.open(os.path.join(out_dir, "model.ts"), "wb") as handle:
            torch.jit.save(scripted, handle)
        dump_torchscript_IR(scripted, out_dir)
    return None

# experimental. API not yet final
def export_scripting(torch_model):
    """Export ``torch_model`` to TorchScript using detectron2's ``scripting_with_instances``.

    Reads the module-level ``args`` parsed in ``__main__``: only
    ``args.format == "torchscript"`` is accepted. ``model.ts`` and a dump of the
    TorchScript IR are written to ``args.output``.

    Returns:
        None. No Python-side evaluation wrapper is produced for scripted models,
        so ``--run-eval`` (which asserts a non-None exported model) cannot be used
        with this export method.
    """
    assert TORCH_VERSION >= (1, 8)
    # Names and types of the Instances fields, handed to scripting_with_instances below.
    fields = {
        "proposal_boxes": Boxes,
        "objectness_logits": Tensor,
        "pred_boxes": Boxes,
        "scores": Tensor,
        "pred_classes": Tensor,
        "pred_masks": Tensor,
        "pred_keypoints": torch.Tensor,
        "pred_keypoint_heatmaps": torch.Tensor,
    }
    assert args.format == "torchscript", "Scripting only supports torchscript format."

    class ScriptableAdapterBase(nn.Module):
        # Use this adapter to workaround https://github.com/pytorch/pytorch/issues/46944
        # by not retuning instances but dicts. Otherwise the exported model is not deployable
        def __init__(self):
            super().__init__()
            self.model = torch_model
            # Put the adapter, and the wrapped model as its submodule, in eval mode.
            self.eval()

    if isinstance(torch_model, GeneralizedRCNN):

        # GeneralizedRCNN: call inference() without detector post-processing and
        # return each image's Instances fields as a plain dict.
        class ScriptableAdapter(ScriptableAdapterBase):
            def forward(self, inputs: Tuple[Dict[str, torch.Tensor]]) -> List[Dict[str, Tensor]]:
                instances = self.model.inference(inputs, do_postprocess=False)
                return [i.get_fields() for i in instances]

    else:

        # Any other meta-architecture: call the model directly.
        class ScriptableAdapter(ScriptableAdapterBase):
            def forward(self, inputs: Tuple[Dict[str, torch.Tensor]]) -> List[Dict[str, Tensor]]:
                instances = self.model(inputs)
                return [i.get_fields() for i in instances]

    ts_model = scripting_with_instances(ScriptableAdapter(), fields)
    with PathManager.open(os.path.join(args.output, "model.ts"), "wb") as f:
        torch.jit.save(ts_model, f)
    dump_torchscript_IR(ts_model, args.output)
    # TODO inference in Python now missing postprocessing glue code
    return None

# experimental. API not yet final
def export_tracing(torch_model, inputs):
    """Export ``torch_model`` by tracing it on one sample image.

    Args:
        torch_model: the detector, already in eval mode.
        inputs: detectron2-style list of input dicts; only ``inputs[0]["image"]``
            is used for tracing.

    Reads the module-level ``args`` and ``logger`` from ``__main__``. Depending on
    ``args.format`` it writes ``model.ts`` (plus an IR dump) or ``model.onnx`` into
    ``args.output``; any other format writes nothing. ``torch.onnx.export`` is called
    without ``dynamic_axes``, so the ONNX input is presumably fixed to the traced
    image's shape. NOTE(review): confirm; this would explain why consumers must
    resize to exactly that size.

    Returns:
        For torchscript exports of GeneralizedRCNN/RetinaNet models, a callable
        that runs the saved TorchScript model and rescales detections to the
        original image size (used by ``--run-eval``); otherwise ``None``.
    """
    assert TORCH_VERSION >= (1, 8)
    image = inputs[0]["image"]
    inputs = [{"image": image}]  # remove other unused keys

    if isinstance(torch_model, GeneralizedRCNN):

        def inference(model, inputs):
            # use do_postprocess=False so it returns ROI mask
            inst = model.inference(inputs, do_postprocess=False)[0]
            return [{"instances": inst}]

    else:
        inference = None  # assume that we just call the model directly

    # TracingAdapter flattens the dict/Instances inputs and outputs into tensors
    # so the model can be traced; its schemas are logged below.
    traceable_model = TracingAdapter(torch_model, inputs, inference)

    if args.format == "torchscript":
        ts_model = torch.jit.trace(traceable_model, (image,))
        with PathManager.open(os.path.join(args.output, "model.ts"), "wb") as f:
            torch.jit.save(ts_model, f)
        dump_torchscript_IR(ts_model, args.output)
    elif args.format == "onnx":
        with PathManager.open(os.path.join(args.output, "model.onnx"), "wb") as f:
            torch.onnx.export(traceable_model, (image,), f, opset_version=STABLE_ONNX_OPSET_VERSION)
    logger.info("Inputs schema: " + str(traceable_model.inputs_schema))
    logger.info("Outputs schema: " + str(traceable_model.outputs_schema))

    # The checks below guarantee that ts_model exists (torchscript branch above)
    # before eval_wrapper, which closes over it, is returned.
    if args.format != "torchscript":
        return None
    if not isinstance(torch_model, (GeneralizedRCNN, RetinaNet)):
        return None

    def eval_wrapper(inputs):
        """
        The exported model does not contain the final resize step, which is typically
        unused in deployment but needed for evaluation. We add it manually here.
        """
        input = inputs[0]
        instances = traceable_model.outputs_schema(ts_model(input["image"]))[0]["instances"]
        postprocessed = detector_postprocess(instances, input["height"], input["width"])
        return [{"instances": postprocessed}]

    return eval_wrapper

def get_sample_inputs(args):
    """Return a one-image batch used to drive export.

    Without ``args.sample_image`` the first batch of the first test dataset is
    returned. Otherwise the image is loaded and resized the same way
    ``DefaultPredictor`` does. Relies on the module-level ``cfg`` built in ``__main__``.
    """
    if args.sample_image is None:
        # Fall back to the first batch of the first test dataset.
        loader = build_detection_test_loader(cfg, cfg.DATASETS.TEST[0])
        return next(iter(loader))

    # Load the sample image and preprocess it exactly like DefaultPredictor.
    raw = detection_utils.read_image(args.sample_image, format=cfg.INPUT.FORMAT)
    resize = T.ResizeShortestEdge(
        [cfg.INPUT.MIN_SIZE_TEST, cfg.INPUT.MIN_SIZE_TEST], cfg.INPUT.MAX_SIZE_TEST
    )
    orig_h, orig_w = raw.shape[:2]
    resized = resize.get_transform(raw).apply_image(raw)
    chw_image = torch.as_tensor(resized.astype("float32").transpose(2, 0, 1))
    return [{"image": chw_image, "height": orig_h, "width": orig_w}]

def get_clip_embeddings(vocabulary, prompt='a '):
    """Encode each class name, prefixed with ``prompt``, with Detic's pretrained text encoder.

    The encoder output is detached, transposed with ``permute(1, 0)``, made
    contiguous and moved to the CPU. This is presumably the D x C layout that
    ``reset_cls_test`` expects; confirm against the encoder's output shape.
    """
    from detic.modeling.text.text_encoder import build_text_encoder

    encoder = build_text_encoder(pretrain=True)
    encoder.eval()
    prompted = [prompt + name for name in vocabulary]
    features = encoder(prompted).detach()
    return features.permute(1, 0).contiguous().cpu()

def reset_cls_test(model, cls_path, num_classes):
    import numpy as np
    from torch.nn import functional as F

    model.roi_heads.num_classes = num_classes
    if type(cls_path) == str:
        print('Resetting zs_weight', cls_path)
        zs_weight = torch.tensor(
            np.load(cls_path),
            dtype=torch.float32).permute(1, 0).contiguous() # D x C
    else:
        zs_weight = cls_path
    zs_weight = torch.cat(
        [zs_weight, zs_weight.new_zeros((zs_weight.shape[0], 1))],
        dim=1) # D x (C + 1)
    if model.roi_heads.box_predictor[0].cls_score.norm_weight:
        zs_weight = F.normalize(zs_weight, p=2, dim=0)
    zs_weight = zs_weight.to(model.device)
    for k in range(len(model.roi_heads.box_predictor)):
        del model.roi_heads.box_predictor[k].cls_score.zs_weight
        model.roi_heads.box_predictor[k].cls_score.zs_weight = zs_weight

if __name__ == "__main__":
    parser = argparse.ArgumentParser(description="Export a model for deployment.")
    parser.add_argument(
        "--format",
        choices=["caffe2", "onnx", "torchscript"],
        help="output format",
        default="torchscript",
    )
    parser.add_argument(
        "--export-method",
        choices=["caffe2_tracing", "tracing", "scripting"],
        help="Method to export models",
        default="tracing",
    )
    parser.add_argument("--config-file", default="", metavar="FILE", help="path to config file")
    parser.add_argument("--sample-image", default=None, type=str, help="sample image for input")
    parser.add_argument("--run-eval", action="store_true")
    parser.add_argument("--output", help="output directory for the converted model")
    parser.add_argument(
        "--custom-vocabulary",
        default=None,
        type=str,
        help="comma-separated class names to bake into the exported model "
        "(default: the built-in kitchenware list)",
    )
    parser.add_argument(
        "opts",
        help="Modify config options using the command-line",
        default=None,
        nargs=argparse.REMAINDER,
    )
    args = parser.parse_args()
    logger = setup_logger()
    logger.info("Command line arguments: " + str(args))
    PathManager.mkdirs(args.output)
    # Disable re-specialization on new shapes. Otherwise --run-eval will be slow
    torch._C._jit_set_bailout_depth(1)

    cfg = setup_cfg(args)

    # Default custom vocabulary. Predicted class ids index into this list, so any
    # inference code must use the same names in the same order (note that
    # 'skimmer' appears twice and therefore occupies two class ids).
    custom_classes = ['scoop', 'teaspoon', 'spoon', 'tea_spoon', 'flatware', 'tong', 'coffee_spoon', 'soupspoon', 'soup_spoon', 'spatula', 'ladle', 'skimmer', 'bowl', 'egg_bowl', 'sugar_bowl', 'washing_bowl', 'salad_bowl', 'cereal_bowl', 'soup_bowl', 'saucepan', 'frying_pan', 'pan', 'cake_pan', 'sauce_pan', 'content_pan', 'wok', 'saucer', 'plate', 'chinaware', 'glass', 'wine_glass', 'chalice', 'dixie_cup', 'flute_glass', 'shot_glass', 'wineglass', 'milk_bottle', 'bottle', 'water_bottle', 'wine_bottle', 'beer_bottle', 'tea_pot', 'pot', 'pressure_pot', 'pasta_pot', 'plastic_pot', 'sauce_pot', 'teapot', 'crock_pot', 'crockpot', 'cup', 'measuring_cup', 'coffee_cup', 'mug', 'teacup', 'tea_cup', 'pitcher', 'coffee_jar', 'sugar_jar', 'honey_jar', 'jar', 'jug', 'coffeepot', 'kettle', 'water_jug', 'urn', 'cream_pitcher', 'coffee_pot', 'container', 'lunch_box', 'sugar_container', 'milk_container', 'rice_container', 'sauce_container', 'food_container', 'casserole', 'knife', 'steak_knife', 'knife_sharpener', 'lime_squeezer', 'peeler', 'grater', 'skimmer', 'cheese_grater', 'masher', 'squeezer', 'potato_peeler', 'lime_juicer', 'scissor', 'tray', 'baking_tray', 'pizza_tray', 'baking_pan', 'serving_board', 'eating_board', 'chopping_board', 'cut_board', 'cutting_board', 'board', 'pasta_strainer', 'strainer', 'mesh_strainer', 'can', 'beer_can', 'milk_can', 'canister', 'wine_bucket', 'bucket', 'plastic_bucket']
    if args.custom_vocabulary:
        custom_classes = [name.strip() for name in args.custom_vocabulary.split(",") if name.strip()]
        if not custom_classes:
            parser.error("--custom-vocabulary must contain at least one class name")
    num_classes = len(custom_classes)
    # CLIP text embeddings of the vocabulary become the zero-shot classifier weights.
    classifier = get_clip_embeddings(custom_classes)

    torch_model = build_model(cfg)
    DetectionCheckpointer(torch_model).resume_or_load(cfg.MODEL.WEIGHTS)
    torch_model.eval()
    logger.info(
        "Replacing the %d-class vocabulary with %d custom classes",
        torch_model.roi_heads.num_classes,
        num_classes,
    )
    reset_cls_test(torch_model, classifier, num_classes)

    # convert and save model
    if args.export_method == "caffe2_tracing":
        sample_inputs = get_sample_inputs(args)
        exported_model = export_caffe2_tracing(cfg, torch_model, sample_inputs)
    elif args.export_method == "scripting":
        exported_model = export_scripting(torch_model)
    elif args.export_method == "tracing":
        sample_inputs = get_sample_inputs(args)
        exported_model = export_tracing(torch_model, sample_inputs)

    # run evaluation with the converted model
    if args.run_eval:
        assert exported_model is not None, (
            "Python inference is not yet implemented for "
            f"export_method={args.export_method}, format={args.format}."
        )
        logger.info("Running evaluation ... this takes a long time if you export to CPU.")
        dataset = cfg.DATASETS.TEST[0]
        data_loader = build_detection_test_loader(cfg, dataset)
        # NOTE: hard-coded evaluator. change to the evaluator for your dataset
        evaluator = COCOEvaluator(dataset, output_dir=args.output)
        metrics = inference_on_dataset(exported_model, data_loader, evaluator)
        print_csv_format(metrics)
    logger.info("Success.")

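The `custom_classes = ['scoop', ...]` list in the script's `__main__` block is where I added my custom labels. The inference script below maps predicted class ids back to names with the same list, so keep both lists identical and in the same order.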

Step 4) Run the script from Detic's root folder with the command `python3 detectron2/tools/deploy/export_model_lvis_vocabulary.py --config-file configs/Detic_LCOCOI21k_CLIP_SwinB_896b32_4x_ft4x_max-size.yaml --sample-image desk.jpg --output ./output --export-method tracing --format onnx MODEL.WEIGHTS models/Detic_LCOCOI21k_CLIP_SwinB_896b32_4x_ft4x_max-size.pth MODEL.DEVICE cuda`. This saves the ONNX model at `output/model.onnx`.

gigasurgeon commented 10 months ago

To run inference with the ONNX model, I am using @HtwoOtwo's script from this comment: https://github.com/facebookresearch/Detic/issues/107#issuecomment-1803014546

The slightly modified inference script looks like this:

import argparse
import cv2
import numpy as np
import onnxruntime as ort
import time

class Detic():
    """ONNX Runtime wrapper around a Detic model exported with a fixed custom vocabulary.

    Args:
        modelpath: path to the exported ONNX file.
        detection_width: stored as ``self.max_size``. It is not used for resizing,
            because the model input is resized to a fixed size (see ``preprocess``).
        confThreshold: a detection is kept only if its score is strictly greater
            than this value.
    """

    # Input (height, width) used when the model does not declare a static input
    # shape. This is the size the script originally hard-coded (presumably the
    # size of the image the export was traced with).
    DEFAULT_INPUT_HW = (800, 1067)

    def __init__(self, modelpath, detection_width=800, confThreshold=0.8):
        # CPU is listed after CUDA so ONNX Runtime can still run where CUDA is unavailable.
        providers = ['CUDAExecutionProvider', 'CPUExecutionProvider']
        self.session = ort.InferenceSession(modelpath, providers=providers)
        model_inputs = self.session.get_inputs()
        self.input_name = model_inputs[0].name
        # (height, width) the model expects, or None if the shape is dynamic/unknown.
        self.input_hw = self._static_input_hw(model_inputs[0].shape)
        self.max_size = detection_width
        self.confThreshold = confThreshold
        # Must match, in order, the vocabulary the model was exported with:
        # pred_classes are indices into this list.
        self.class_names = ['scoop', 'teaspoon', 'spoon', 'tea_spoon', 'flatware', 'tong', 'coffee_spoon', 'soupspoon', 'soup_spoon', 'spatula',
                            'ladle', 'skimmer', 'bowl', 'egg_bowl', 'sugar_bowl', 'washing_bowl', 'salad_bowl', 'cereal_bowl', 'soup_bowl', 'saucepan',
                            'frying_pan', 'pan', 'cake_pan', 'sauce_pan', 'content_pan', 'wok', 'saucer', 'plate', 'chinaware', 'glass', 'wine_glass',
                            'chalice', 'dixie_cup', 'flute_glass', 'shot_glass', 'wineglass', 'milk_bottle', 'bottle', 'water_bottle', 'wine_bottle',
                            'beer_bottle', 'tea_pot', 'pot', 'pressure_pot', 'pasta_pot', 'plastic_pot', 'sauce_pot', 'teapot', 'crock_pot', 'crockpot',
                            'cup', 'measuring_cup', 'coffee_cup', 'mug', 'teacup', 'tea_cup', 'pitcher', 'coffee_jar', 'sugar_jar', 'honey_jar', 'jar',
                            'jug', 'coffeepot', 'kettle', 'water_jug', 'urn', 'cream_pitcher', 'coffee_pot', 'container', 'lunch_box', 'sugar_container',
                            'milk_container', 'rice_container', 'sauce_container', 'food_container', 'casserole', 'knife', 'steak_knife', 'knife_sharpener',
                            'lime_squeezer', 'peeler', 'grater', 'skimmer', 'cheese_grater', 'masher', 'squeezer', 'potato_peeler', 'lime_juicer', 'scissor',
                            'tray', 'baking_tray', 'pizza_tray', 'baking_pan', 'serving_board', 'eating_board', 'chopping_board', 'cut_board', 'cutting_board',
                            'board', 'pasta_strainer', 'strainer', 'mesh_strainer', 'can', 'beer_can', 'milk_can', 'canister', 'wine_bucket', 'bucket', 'plastic_bucket']

        # One random colour per class, so indexing by class id is always valid.
        # draw_predictions currently draws everything in green.
        self.assigned_colors = np.random.randint(0, high=256, size=(len(self.class_names), 3)).tolist()

    @staticmethod
    def _static_input_hw(shape):
        """Return (height, width) from an ONNX input shape if both are positive ints, else None."""
        if shape is None or len(shape) < 2:
            return None
        height, width = shape[-2], shape[-1]
        if isinstance(height, int) and isinstance(width, int) and height > 0 and width > 0:
            return height, width
        return None

    def preprocess(self, srcimg):
        """Convert a BGR image to RGB and resize it to the model's input size.

        The image is stretched to exactly that size, so the aspect ratio is not
        preserved. ``post_processing`` undoes the stretch with separate x and y
        scale factors. The target size is the model's static input shape when it
        declares one (the exporter passes no dynamic axes), else ``DEFAULT_INPUT_HW``.
        """
        dstimg = cv2.cvtColor(srcimg, cv2.COLOR_BGR2RGB)
        in_h, in_w = getattr(self, 'input_hw', None) or self.DEFAULT_INPUT_HW
        # cv2.resize takes dsize as (width, height).
        return cv2.resize(dstimg, (in_w, in_h))

    def suppress_overlapping_bboxes(self, pred_boxes, scores, pred_classes, pred_masks):
        """Drop exact-duplicate boxes, keeping the highest-scoring one of each group.

        Boxes are compared after conversion to int64 pixel coordinates. This is
        exact-duplicate removal, not IoU-based NMS. Surviving detections keep the
        order in which each distinct box first appears; ties keep the earlier one.

        Returns:
            (pred_boxes as int64, scores, pred_classes, pred_masks), all filtered
            with the same index array so their leading dimensions stay aligned
            (an empty input keeps shape (0, 4) for boxes).
        """
        pred_boxes = pred_boxes.astype(np.int64)

        best = {}
        for i, box in enumerate(pred_boxes):
            key = tuple(box.tolist())
            # Compare against the *score* of the box kept so far (not its index).
            if key not in best or scores[i] > scores[best[key]]:
                best[key] = i

        keep = np.array(list(best.values()), dtype=np.int64)
        return pred_boxes[keep], scores[keep], pred_classes[keep], pred_masks[keep]

    def post_processing(self, pred_boxes, scores, pred_classes, pred_masks, im_hw, pred_hw):
        """Map raw predictions back to the original image and filter them.

        Args:
            pred_boxes: float array of x0, y0, x1, y1 boxes in model-input pixels.
                It is rescaled and clipped in place.
            scores, pred_classes, pred_masks: per-detection arrays aligned with pred_boxes.
            im_hw: (height, width) of the original image.
            pred_hw: (height, width) of the image fed to the model.

        Returns:
            dict with 'pred_boxes' (int64), 'scores', 'pred_classes' and 'pred_masks'
            for detections with positive area, score > ``self.confThreshold`` and
            no exact duplicate with a higher score.
        """
        scale_x = im_hw[1] / pred_hw[1]
        scale_y = im_hw[0] / pred_hw[0]

        pred_boxes[:, 0::2] *= scale_x
        pred_boxes[:, 1::2] *= scale_y
        pred_boxes[:, [0, 2]] = np.clip(pred_boxes[:, [0, 2]], 0, im_hw[1])
        pred_boxes[:, [1, 3]] = np.clip(pred_boxes[:, [1, 3]], 0, im_hw[0])

        # Keep boxes with positive area that clear the configured confidence threshold.
        widths = pred_boxes[:, 2] - pred_boxes[:, 0]
        heights = pred_boxes[:, 3] - pred_boxes[:, 1]
        keep = (widths > 0) & (heights > 0) & (scores > self.confThreshold)

        pred_boxes, scores, pred_classes, pred_masks = self.suppress_overlapping_bboxes(
            pred_boxes[keep], scores[keep], pred_classes[keep], pred_masks[keep])

        return {
            'pred_boxes': pred_boxes,
            'scores': scores,
            'pred_classes': pred_classes,
            'pred_masks': pred_masks,
        }

    def draw_predictions(self, img, predictions):
        """Draw each box with a coordinate/score/class label onto ``img`` in place and return it."""
        height, width = img.shape[:2]
        default_font_size = int(max(np.sqrt(height * width) // 90, 10))
        boxes = predictions["pred_boxes"].astype(np.int64)
        scores = predictions["scores"]
        classes_id = predictions["pred_classes"].tolist()
        num_instances = len(boxes)
        print('detect', num_instances, 'instances')

        color = [0, 255, 0]
        for box, score, class_id in zip(boxes, scores, classes_id):
            # Plain Python ints for the OpenCV drawing calls.
            x0, y0, x1, y1 = (int(v) for v in box)
            cv2.rectangle(img, (x0, y0), (x1, y1), color=color, thickness=default_font_size // 4)
            text = f"{x0}_{y0}_{x1}_{y1} {round(score,2)} {self.class_names[class_id]}"
            print(text)
            cv2.putText(img, text, (x0, y0 - 5), cv2.FONT_HERSHEY_SIMPLEX, 0.5, color, thickness=1, lineType=cv2.LINE_AA)
        return img

    def detect(self, srcimg):
        """Run the model on a BGR image and return post-processed predictions."""
        im_h, im_w = srcimg.shape[:2]
        dstimg = self.preprocess(srcimg)
        pred_hw = dstimg.shape[:2]
        # The exported model takes a single CHW float32 image (no batch dimension).
        input_image = dstimg.transpose(2, 0, 1).astype(np.float32)

        # Output order follows the traced model's outputs.
        pred_boxes, pred_classes, pred_masks, scores, _ = self.session.run(None, {self.input_name: input_image})
        return self.post_processing(pred_boxes, scores, pred_classes, pred_masks, (im_h, im_w), pred_hw)

if __name__ == '__main__':
    parser = argparse.ArgumentParser()
    parser.add_argument("--imgpath", default="desk.jpg", type=str, help="image path")
    parser.add_argument("--confThreshold", default=0.5, type=float, help='class confidence')
    parser.add_argument("--modelpath", type=str, default='onnx_models/model_custom_vocabulary.onnx', help="onnxmodel path")
    parser.add_argument("--runs", default=1, type=int, help="number of timed inference runs (at least 1)")
    args = parser.parse_args()

    # cv2.imread returns None instead of raising when the file is missing or
    # unreadable; fail early with a clear message instead of inside detect().
    srcimg = cv2.imread(args.imgpath)
    if srcimg is None:
        parser.error(f"could not read image: {args.imgpath}")

    mynet = Detic(args.modelpath, confThreshold=args.confThreshold)

    fpses = []
    for i in range(max(1, args.runs)):
        print(i)
        # perf_counter is a monotonic, high-resolution clock suited to timing.
        t1 = time.perf_counter()
        preds = mynet.detect(srcimg)
        t2 = time.perf_counter()
        fpses.append(1 / (t2 - t1))
    avg_fps = sum(fpses) / len(fpses)
    print(f'avg_fps: {round(avg_fps, 2)}')
    # draw_predictions draws onto srcimg in place; it is not reused afterwards.
    result = mynet.draw_predictions(srcimg, preds)

    cv2.imwrite('result_onnx.jpg', result)
antoniodecinque99 commented 9 months ago

Hello @gigasurgeon, thanks for the tutorial. Would you be able to upload directly the onnx file you produced with the script? Thank you so much

gigasurgeon commented 9 months ago

Hello @gigasurgeon, thanks for the tutorial. Would you be able to upload directly the onnx file you produced with the script? Thank you so much

Here's the ONNX file -> https://drive.google.com/file/d/1hYz19lZk4ugLrUGO0HIP9M2RbXs5A4O-/view?usp=sharing