aws / sagemaker-pytorch-inference-toolkit

Toolkit for inference and serving with PyTorch on SageMaker. Dockerfiles used for building SageMaker PyTorch Containers are at https://github.com/aws/deep-learning-containers.
Apache License 2.0

[Question] Using model.mar with built-in handler script #133

Open austinmw opened 1 year ago

austinmw commented 1 year ago

What did you find confusing? Please describe.
Hi, I've recently used the TorchServe export utility provided by the MMDetection library. It uses the Torch Model Archiver to package the following files into a model.mar (a rough sketch of the equivalent archiver call follows the listing):

Archive:  model.mar
  bbox_mAP_epoch_1.pth
  config.py
  mmdet_handler.py
  MAR-INF/MANIFEST.json
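
As far as I can tell, the export utility does roughly the equivalent of the archiver call below. This is only my reconstruction: the model name, version, and output path are guesses, not the utility's actual values.

import subprocess

# Rough equivalent of what MMDetection's export utility appears to do
# (model name, version, and export path are assumptions).
subprocess.run(
    [
        'torch-model-archiver',
        '--model-name', 'mmdet_model',
        '--version', '1.0',
        '--serialized-file', 'bbox_mAP_epoch_1.pth',
        '--handler', 'mmdet_handler.py',
        '--extra-files', 'config.py',
        '--export-path', '.',
    ],
    check=True,
)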

Is it possible to use this model.mar directly with SM TorchServe without needing to pull out the handler script and reformat it?

Additional context
I've added mmdet to my requirements.txt, so the dependencies are not a problem.
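
For reference, the deployment flow I'm hoping for looks roughly like the sketch below. The S3 path, instance type, and framework/Python versions are placeholders, and it assumes the container's default TorchServe setup could serve the .mar as-is, which is exactly what I'm asking about.

import sagemaker
from sagemaker.pytorch import PyTorchModel

# Sketch only: model.tar.gz is assumed to contain the exported model.mar unchanged.
model = PyTorchModel(
    model_data='s3://my-bucket/mmdet/model.tar.gz',  # placeholder path
    role=sagemaker.get_execution_role(),
    framework_version='2.0',  # assumed DLC version
    py_version='py310',
    # no entry_point passed, on the assumption the packaged handler could be used directly
)

predictor = model.deploy(
    initial_instance_count=1,
    instance_type='ml.g4dn.xlarge',
)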

The mmdet_handler.py script looks like this:

# Copyright (c) OpenMMLab. All rights reserved.
import base64
import os

import mmcv
import numpy as np
import torch
from ts.torch_handler.base_handler import BaseHandler

from mmdet.apis import inference_detector, init_detector
from mmdet.utils import register_all_modules

register_all_modules(True)

class MMdetHandler(BaseHandler):
    threshold = 0.5

    def initialize(self, context):
        properties = context.system_properties
        self.map_location = 'cuda' if torch.cuda.is_available() else 'cpu'
        # Bind to the GPU that TorchServe assigned to this worker, if one is available.
        if torch.cuda.is_available():
            self.device = torch.device(
                self.map_location + ':' + str(properties.get('gpu_id')))
        else:
            self.device = torch.device(self.map_location)
        self.manifest = context.manifest

        model_dir = properties.get('model_dir')
        serialized_file = self.manifest['model']['serializedFile']
        checkpoint = os.path.join(model_dir, serialized_file)
        self.config_file = os.path.join(model_dir, 'config.py')

        self.model = init_detector(self.config_file, checkpoint, self.device)
        self.initialized = True

    def preprocess(self, data):
        images = []

        for row in data:
            image = row.get('data') or row.get('body')
            if isinstance(image, str):
                image = base64.b64decode(image)
            image = mmcv.imfrombytes(image)
            images.append(image)

        return images

    def inference(self, data, *args, **kwargs):
        results = inference_detector(self.model, data)
        return results

    def postprocess(self, data):
        # Format output following the example ObjectDetectionHandler format
        output = []
        for data_sample in data:
            pred_instances = data_sample.pred_instances
            bboxes = pred_instances.bboxes.cpu().numpy().astype(np.float32).tolist()
            labels = pred_instances.labels.cpu().numpy().astype(np.int32).tolist()
            scores = pred_instances.scores.cpu().numpy().astype(np.float32).tolist()
            preds = []
            for idx in range(len(labels)):
                cls_score, bbox, cls_label = scores[idx], bboxes[idx], labels[idx]
                if cls_score >= self.threshold:
                    class_name = self.model.dataset_meta['CLASSES'][cls_label]
                    result = dict(
                        class_label=cls_label,
                        class_name=class_name,
                        bbox=bbox,
                        score=cls_score)
                    preds.append(result)
            output.append(preds)
        return output
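
If repackaging turns out to be unavoidable, my understanding is that the logic above would have to be reshaped into the toolkit's inference-script hooks, roughly as sketched below. model_fn/input_fn/predict_fn/output_fn are the toolkit's standard functions, but the content handling, checkpoint filename, and GPU assumption are all mine.

import base64
import json
import os

import mmcv
import numpy as np

from mmdet.apis import inference_detector, init_detector
from mmdet.utils import register_all_modules

register_all_modules(True)

THRESHOLD = 0.5


def model_fn(model_dir):
    # Assumes config.py and the checkpoint were unpacked into model_dir.
    config_file = os.path.join(model_dir, 'config.py')
    checkpoint = os.path.join(model_dir, 'bbox_mAP_epoch_1.pth')
    return init_detector(config_file, checkpoint, device='cuda:0')  # assumes a GPU instance


def input_fn(request_body, content_type):
    # Accept either raw image bytes or a base64-encoded string.
    if isinstance(request_body, str):
        request_body = base64.b64decode(request_body)
    return mmcv.imfrombytes(request_body)


def predict_fn(image, model):
    return inference_detector(model, image)


def output_fn(data_sample, accept):
    pred = data_sample.pred_instances
    keep = pred.scores.cpu().numpy() >= THRESHOLD
    result = {
        'bboxes': pred.bboxes.cpu().numpy()[keep].astype(np.float32).tolist(),
        'labels': pred.labels.cpu().numpy()[keep].astype(np.int32).tolist(),
        'scores': pred.scores.cpu().numpy()[keep].astype(np.float32).tolist(),
    }
    return json.dumps(result)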