facebookresearch / d2go

D2Go is a toolkit for efficient deep learning
Apache License 2.0

Exporting pretrained keypoint_rcnn_fbnetv3a_dsmask_C4 #38

Open sadegh16 opened 3 years ago

sadegh16 commented 3 years ago

Hello all

Has anyone managed to export a TorchScript file for the pretrained keypoint_rcnn_fbnetv3a_dsmask_C4 model using the create_d2go.py file? I altered the Wrapper to return "keypoints" alongside the other outputs ("boxes", "scores", "labels"). In the Wrapper, the keypoints are in out[3] (the original Wrapper uses res["scores"] = out[2]).

Here is the code I use to export the model:

```python
#!/usr/bin/env python3

import copy
import os
from typing import List

import cv2
import torch
from d2go.export.api import convert_and_export_predictor
from d2go.export.d2_meta_arch import patch_d2_meta_arch
from d2go.model_zoo import model_zoo
from d2go.utils.testing.data_loader_helper import create_fake_detection_data_loader

patch_d2_meta_arch()

cfg_name = "keypoint_rcnn_fbnetv3a_dsmask_C4.yaml"
pytorch_model = model_zoo.get(cfg_name, trained=True)
# pytorch_model.training = False
# pytorch_model.eval()


class Wrapper(torch.nn.Module):
    def __init__(self, model):
        super().__init__()
        self.model = model
        # Maps the model's class indices to COCO category IDs
        coco_idx_list = [1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 13, 14, 15, 16, 17, 18, 19, 20, 21, 22, 23, 24, 25,
                         27, 28, 31, 32, 33, 34, 35, 36, 37, 38, 39, 40, 41, 42, 43, 44, 46, 47, 48, 49, 50, 51,
                         52, 53, 54, 55, 56, 57, 58, 59, 60, 61, 62, 63, 64, 65, 67, 70, 72, 73, 74, 75, 76, 77,
                         78, 79, 80, 81, 82, 84, 85, 86, 87, 88, 89, 90, 91]
        self.coco_idx = torch.tensor(coco_idx_list)

    def forward(self, inputs: List[torch.Tensor]):
        # inputs[0]: CHW float image in [0, 1]; scale back to [0, 255]
        x = inputs[0].unsqueeze(0) * 255
        # Resize so that the shorter side is 320 px
        scale = 320.0 / min(x.shape[-2], x.shape[-1])
        x = torch.nn.functional.interpolate(
            x, scale_factor=scale, mode="bilinear", align_corners=True, recompute_scale_factor=True
        )
        out = self.model(x[0])

        # Assumes out[3] holds keypoints as (N, K, 3) = (x, y, score) in resized-image
        # coordinates; divide x and y by `scale`, as is done for the boxes
        keypoints = torch.cat([out[3][..., :2] / scale, out[3][..., 2:]], dim=-1)
        res = (
            keypoints,                                     # keypoints
            out[0] / scale,                                # boxes
            torch.index_select(self.coco_idx, 0, out[1]),  # labels (COCO IDs)
            out[4],                                        # scores
        )
        return res


size_divisibility = max(pytorch_model.backbone.size_divisibility, 10)
h, w = size_divisibility, size_divisibility * 2
with create_fake_detection_data_loader(h, w, is_train=False) as data_loader:
    predictor_path = convert_and_export_predictor(
        model_zoo.get_config(cfg_name),
        copy.deepcopy(pytorch_model),
        "torchscript_int8@tracing",
        "./",
        data_loader,
    )

    orig_model = torch.jit.load(os.path.join(predictor_path, "model.jit"))
    wrapped_model = Wrapper(orig_model)

    # Optionally run a forward pass on a test image
    im = cv2.imread("inp8.jpg")  # HWC, uint8, BGR
    # HWC -> CHW via permute (a plain reshape would scramble the pixels), scaled to [0, 1]
    im = torch.from_numpy(im).permute(2, 0, 1).float() / 255
    wrapped_model([im])

    scripted_model = torch.jit.script(wrapped_model)
    scripted_model.save("d2go_tracker_temp.pt")
```

When I use the exported .pt file on Android, I get corrupted keypoint coordinates. This seems to be caused by the TracerWarnings printed while exporting the TorchScript file (converting a tensor to other Python types causes the value to become a constant in the exported TorchScript file).
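The warning describes a general property of `torch.jit.trace`: it records tensor operations only, so any value converted to a Python number (for example with `float()`, `int()` or `.item()`) is frozen at whatever it was for the example input. The toy functions below are hypothetical and unrelated to D2Go; they only illustrate the effect, and whether it explains the keypoint problem in this issue is an assumption.

```python
import warnings

import torch


def scale_by_sum_python(x: torch.Tensor) -> torch.Tensor:
    # float() turns the sum into a Python number, so tracing bakes it in as a constant
    return x * float(x.sum())


def scale_by_sum_tensor(x: torch.Tensor) -> torch.Tensor:
    # Stays a tensor op, so the tracer records it and it is recomputed for every input
    return x * x.sum()


def trace_quietly(fn, example: torch.Tensor):
    # Hide the TracerWarning that scale_by_sum_python triggers during tracing
    with warnings.catch_warnings():
        warnings.simplefilter("ignore")
        return torch.jit.trace(fn, example)


if __name__ == "__main__":
    example = torch.ones(2)             # sum == 2 while tracing
    new_input = torch.full((2,), 3.0)   # sum == 6 at inference time

    traced_python = trace_quietly(scale_by_sum_python, example)
    traced_tensor = trace_quietly(scale_by_sum_tensor, example)

    print(scale_by_sum_python(new_input).tolist())  # eager: [18.0, 18.0]
    print(traced_python(new_input).tolist())        # traced: [6.0, 6.0], the sum 2.0 was frozen
    print(traced_tensor(new_input).tolist())        # traced: [18.0, 18.0]
```

Keeping such values as tensors, or compiling the affected logic with `torch.jit.script` instead of tracing it, avoids the frozen constant.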

I am fairly sure the input format of the forward pass on Android is correct. On Android, the keypoint_rcnn_fbnetv3a_dsmask_C4 model outputs "boxes", "scores", "labels", and "keypoints", but the "keypoints" are wrong. The other outputs are fine, and I can draw boxes around "persons".

By corrupted keypoints I mean that, on Android, every keypoint returned by the model has the same (x, y, probability) values.

## Expected behavior

The same output as when I run the model in Python 3 with DemoPredictor.
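One way to make "corrupted" measurable on both sides is to check whether every keypoint of a detection carries the same (x, y, score) triple, and to compare against the Python output for the same image. The helpers below are hypothetical (not part of D2Go) and assume keypoints arrive as an (N, K, 3) tensor of (x, y, score), as in Detectron2's `pred_keypoints`.

```python
import torch


def collapsed_instances(keypoints: torch.Tensor) -> torch.Tensor:
    """Return a bool mask of shape (N,) marking detections whose K keypoints are all identical.

    Assumes keypoints has shape (N, K, 3) holding (x, y, score) per keypoint.
    """
    first = keypoints[:, :1, :]  # (N, 1, 3), broadcast against all K keypoints
    return (keypoints == first).all(dim=2).all(dim=1)


def max_keypoint_error(reference: torch.Tensor, candidate: torch.Tensor) -> float:
    """Largest absolute (x, y) difference between two keypoint outputs for the same image.

    Assumes both tensors are (N, K, 3) and list the same detections in the same order.
    """
    if reference.shape != candidate.shape:
        raise ValueError(f"shape mismatch: {tuple(reference.shape)} vs {tuple(candidate.shape)}")
    if reference.numel() == 0:
        return 0.0
    return (reference[..., :2] - candidate[..., :2]).abs().max().item()


if __name__ == "__main__":
    # Toy data: detection 0 has collapsed keypoints, detection 1 does not
    kpts = torch.tensor([
        [[5.0, 7.0, 0.9], [5.0, 7.0, 0.9]],
        [[1.0, 2.0, 0.8], [3.0, 4.0, 0.7]],
    ])
    print(collapsed_instances(kpts).tolist())  # [True, False]
```

Running the same check on the Python output (for example from DemoPredictor) and on the values read back on Android helps narrow down whether the export or the Android-side handling is at fault.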

sadegh16 commented 3 years ago

Could anyone please help me with this? I have updated the code above.

XiaoliangDai commented 3 years ago

Hi there, thank you for your feedback! Can you provide the detailed tracer warning? Also, could you try exporting the int8 model.jit with https://github.com/facebookresearch/d2go/tree/master/demo and test the model.jit and wrapper again?
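To collect the full tracer warnings, Python's standard `warnings` module can record them, together with the file and line that triggered each one, while tracing runs. `torch.jit.TracerWarning` is the warning class PyTorch's tracer emits; the helper and the toy function below are hypothetical, and the toy function only stands in for the real export call.

```python
import warnings
from typing import Callable, List, Tuple, TypeVar

import torch

T = TypeVar("T")


def record_tracer_warnings(run: Callable[[], T]) -> Tuple[T, List[str]]:
    """Call run() and return its result plus every TracerWarning as 'file:line: message'."""
    with warnings.catch_warnings(record=True) as caught:
        warnings.simplefilter("always")  # record repeated warnings instead of showing each once
        result = run()
    report = [
        f"{w.filename}:{w.lineno}: {w.message}"
        for w in caught
        if issubclass(w.category, torch.jit.TracerWarning)
    ]
    return result, report


def toy_forward(x: torch.Tensor) -> torch.Tensor:
    # Hypothetical stand-in for the real model: float() leaves the traced graph
    return x * float(x.max())


if __name__ == "__main__":
    traced, report = record_tracer_warnings(lambda: torch.jit.trace(toy_forward, torch.ones(3)))
    for line in report:
        print(line)
```

With the script from this issue, the lambda would wrap the `convert_and_export_predictor(...)` call instead; the recorded file:line locations should usually point at the Python code that converts a tensor to a plain value.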

sadegh16 commented 3 years ago

@XiaoliangDai

I got new errors while loading the model on Android.

[Screenshot of the error messages]

This happens even before an image is fed to the model, a step that worked before.

sadegh16 commented 3 years ago

@XiaoliangDai

Using this, after setting up all the datasets (annotations in the defined directory), I get:

[Screenshot of the error output]

Could you help me? Even the sample int8 quantization in the link above gives warnings like these.

wat3rBro commented 3 years ago

@sadegh16 Hi, is this still an issue? The latest error message seems to be related to data loading (maybe the dataset is not installed correctly). To debug, could you disable error handling by appending `D2GO_DATA.MAPPER.CATCH_EXCEPTION False` to the command?

ashutoshsoni891 commented 3 years ago

Getting the same error. Any luck, @wat3rBro @sadegh16?