alibaba / TinyNeuralNetwork

TinyNeuralNetwork is an efficient and easy-to-use deep learning model compression framework.
MIT License

How to add NMS to a PyTorch model so that it gets converted into TFLite #16

Open gj-raza opened 2 years ago

gj-raza commented 2 years ago

First off, thanks for this amazing repo. I'm working on an SSD model in PyTorch and I want to add post-processing (NMS) to the TFLite model. How can I add it to my model so that it gets translated to TFLite's NMS op? Thanks.

peterjc123 commented 2 years ago

@gj-raza This is currently unsupported. We will take a look later. This shouldn't be too difficult I guess.

peterjc123 commented 2 years ago

@gj-raza Is torchvision.ops.nms sufficient for your usage?

gj-raza commented 2 years ago

JFYI, I've added the TFLite custom post-processing op to a Keras object detection model following this method, but since PyTorch has no NMS layer, it's getting tricky here.

Also, I've noticed that TFLite's NMS v4 and v5 exist in your code here. Can you please explain the overall flow of the TinyNN converter so that I can contribute this feature myself?

gj-raza commented 2 years ago

> @gj-raza Is torchvision.ops.nms sufficient for your usage?

I've tried this, but it doesn't seem to work. Have a look at the code below:

```python
# Simple PyTorch model reproducing the issue
import argparse

import torch
import torch.nn as nn
import torchvision

# TinyNeuralNetwork helpers, as in the quick-start example
# (import paths may differ slightly depending on the version)
from tinynn.converter import TFLiteConverter
from tinynn.graph.quantization.quantizer import QATQuantizer
from tinynn.util.cifar10 import get_dataloader
from tinynn.util.train_util import DLContext, get_device


class PyModel(nn.Module):
    def __init__(self):
        super(PyModel, self).__init__()
        self.conv1 = nn.Conv2d(3, 64, (3, 3), 1, 1)
        self.conv2 = nn.Conv2d(64, 64, (3, 3), 1, 1)

    def forward(self, x):
        out = self.conv1(x)
        out = self.conv2(out)
        out = torch.reshape(out, (640000, 4))
        out1 = out[:, 0]
        out3 = torchvision.ops.nms(out, out1, 0.5)
        return out3


def main_worker(args):
    print("###### TinyNeuralNetwork quick start for beginner ######")
    torch.cuda.empty_cache()
    model = PyModel()

    device = get_device()
    model.to(device=device)

    # Provide a viable input for the model
    dummy_input = torch.rand((1, 3, 200, 200))

    context = DLContext()
    context.device = device
    context.train_loader, context.val_loader = get_dataloader(args.data_path, 220, args.batch_size, args.workers)

    print("\n###### Start preparing the model for quantization ######")
    # TinyNN provides a QATQuantizer class that may rewrite the graph and perform model fusion for quantization
    # The model returned by the `quantize` function is ready for QAT training
    quantizer = QATQuantizer(model, dummy_input, work_dir='out')
    qat_model = quantizer.quantize()

    print("\n###### Start converting the model to TFLite ######")
    with torch.no_grad():
        qat_model.eval()
        qat_model.cpu()

        # The step below converts the model to an actual quantized model, which uses the quantized kernels
        qat_model = torch.quantization.convert(qat_model)
        print(type(qat_model))
        # When converting quantized models to TFLite, please ensure the quantization backend is QNNPACK
        torch.backends.quantized.engine = 'qnnpack'

        # The code section below is used to convert the model to the TFLite format
        converter = TFLiteConverter(qat_model, dummy_input, tflite_path='out/qat_model_small.tflite')
        converter.convert()


if __name__ == '__main__':
    parser = argparse.ArgumentParser()
    parser.add_argument('--data-path', metavar='DIR', default="/data/datasets/cifar10", help='path to cifar10 dataset')
    parser.add_argument('--workers', type=int, default=8)
    parser.add_argument('--batch-size', type=int, default=256)

    args = parser.parse_args()
    main_worker(args)
```

It gives the following error:

```
###### TinyNeuralNetwork quick start for beginner ######

###### Start preparing the model for quantization ######

Traceback (most recent call last):
  File "/home/gj/hazen/alibaba_tinynn/TinyNeuralNetwork/tinynn/graph/tracer.py", line 2061, in trace
    new_graph.init()
  File "/home/gj/hazen/alibaba_tinynn/TinyNeuralNetwork/tinynn/graph/tracer.py", line 1317, in init
    self.module(*actual_input)
  File "/home/gj/anaconda3/envs/pyt-tinynn/lib/python3.6/site-packages/torch/nn/modules/module.py", line 726, in _call_impl
    hook_result = hook(self, input, result)
  File "/home/gj/hazen/alibaba_tinynn/TinyNeuralNetwork/tinynn/graph/tracer.py", line 1110, in _model_tracer
    add_output_node(node, outputs)
  File "/home/gj/hazen/alibaba_tinynn/TinyNeuralNetwork/tinynn/graph/tracer.py", line 982, in add_output_node
    node.prev_nodes.append(current_graph().nodes_map[current_graph().tensor_pre_node_dict[id(t)]])
KeyError: 140563279179208
ERROR (tinynn.graph.tracer) inputs: ['input_1']
ERROR (tinynn.graph.tracer) forwards: ['conv1', 'conv2', 'reshape_1', 'getitem_1']
ERROR (tinynn.graph.tracer) outputs: []
ERROR (tinynn.graph.tracer) constants: []
```

zhiqwang commented 2 years ago

Hi, the torchvision::ops::batched_nms in PyTorch looks like the tensorflow::ops::combined-non-max-suppression (set q to 1) in TensorFlow (I'm not sure about this numerical equivalence and should do some verification), and it's much faster than torchvision.ops.nms. Is it possible to implement this transformation as well?
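
For reference, a minimal sketch of what "set q to 1" means in the TF call; the shapes and thresholds below are illustrative, not taken from any model in this thread:

```python
import tensorflow as tf

batch, num_boxes, num_classes = 1, 100, 4
# q = 1: one class-agnostic box per anchor, shared across all classes
boxes = tf.random.uniform((batch, num_boxes, 1, 4))
scores = tf.random.uniform((batch, num_boxes, num_classes))

nmsed_boxes, nmsed_scores, nmsed_classes, valid = tf.image.combined_non_max_suppression(
    boxes,
    scores,
    max_output_size_per_class=10,
    max_total_size=20,
    iou_threshold=0.5,
)
```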

peterjc123 commented 2 years ago

> Also, I've noticed that TFLite's NMS v4 and v5 exist in your code here. Can you please explain the overall flow of the TinyNN converter so that I can contribute this feature myself?

It is a long story. Currently, you have to make changes to multiple components to make it work. Let me list things to do here.

> The torchvision::ops::batched_nms in PyTorch looks like the tensorflow::ops::combined-non-max-suppression (set q to 1) in TensorFlow, and it's much faster than torchvision.ops.nms. Is it possible to implement this transformation as well?

Yeah, you are right. However, according to this post, combined-non-max-suppression translates to a Flex op in TFLite. Currently, supporting Flex ops is low-priority work for us. Patches are welcome.

peterjc123 commented 2 years ago

@gj-raza I've done the first two tasks. If you are interested, you may take a look at the latter two, which should be fairly easy.

gj-raza commented 2 years ago

> @gj-raza I've done the first two tasks. If you are interested, you may take a look at the latter two, which should be fairly easy.

@peterjc123 Sure, but I have not previously worked with PyTorch/TFLite internals or schemas, so it might take me some time to figure it all out on my own. If there are any documentation, links, tutorials, etc. that you think will be helpful, please share; I'll get on it ASAP.

peterjc123 commented 2 years ago

@gj-raza As for the PyTorch side, the schema is quite clear:

https://github.com/alibaba/TinyNeuralNetwork/blob/main/tinynn/converter/operators/torch/torchvision_schema.py#L40
https://pytorch.org/vision/stable/ops.html#torchvision.ops.nms

With regard to the TFLite side, you may refer to the following docs:

https://www.tensorflow.org/mlir/tfl_ops#tflnon_max_suppression_v4_mlirtflnonmaxsuppressionv4op
https://www.tensorflow.org/api_docs/python/tf/image/non_max_suppression_padded
https://github.com/tensorflow/tensorflow/blob/master/tensorflow/lite/kernels/non_max_suppression.cc#L29-L60
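
To get a quick feel for the PyTorch-side contract being translated, here is a minimal, self-contained use of the op (the values are made up):

```python
import torch
import torchvision

# boxes are (x1, y1, x2, y2); nms returns the indices of the kept boxes, sorted by score
boxes = torch.tensor([[0., 0., 10., 10.],
                      [1., 1., 11., 11.],     # IoU with box 0 is ~0.68
                      [50., 50., 60., 60.]])
scores = torch.tensor([0.9, 0.8, 0.7])
keep = torchvision.ops.nms(boxes, scores, iou_threshold=0.5)
print(keep)  # tensor([0, 2])
```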

To implement NMS translation, you may have to do the following things:

  1. Create a new file torchvision.py under https://github.com/alibaba/TinyNeuralNetwork/blob/main/tinynn/converter/operators/torch/ and import some libraries just as in aten.py.
  2. Create a skeleton class TorchVisionNmsOperator:

        class TorchVisionNmsOperator(TorchVisionNmsSchema):
            def parse(self, node, attrs, args, graph_converter):
                super().parse(node, attrs, args, graph_converter)

                self.run(node)
  3. In the parse function, you may get the input/output arguments or tensors through self.input_tensor and self.output_tensor respectively. We also provide self.find_or_create_input and self.to_tfl_tensors for converting those tensors to the TFL format. If you need TFL tensors other than the I/O tensors, there are also self.create_attr_tensor and self.create_transform_tensor for creating tensors that serve as constants and variables.
  4. As for op creation on the TFLite side, you just need one line, in which the inputs and outputs are lists of tfl.Tensors:
        graph_converter.add_operator(tfl.NonMaxSuppressionV4Operator(inputs, outputs))
  5. Write the translation logic (self.input_tensor, self.output_tensor -> inputs, outputs)
  6. Register the operator translator here.

The tricky parts include (a short sketch of parts 1 and 3 follows this list):

  1. NMS returns a tensor of dynamic size, so you need to pad the PT tensors to a maximum size, since TFLite doesn't support dynamic sizes.
  2. tfl.NonMaxSuppressionV4Operator takes an additional argument score_threshold, which you may need to set to the default value float('-inf').
  3. The bounding-box formats of the two backends differ (TF: [y1, x1, y2, x2]; PT: [x1, y1, x2, y2]), so you may need to reorder the coordinates, e.g. using tfl.GatherND.
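
Here is a rough sketch of parts 1 and 3 in plain PyTorch, just to illustrate the data transforms; the converter itself would emit the equivalent TFLite ops, and the sizes and the -1 padding value are illustrative choices:

```python
import torch

# Part 3: PT boxes (x1, y1, x2, y2) -> TF boxes (y1, x1, y2, x2) via an index gather
pt_boxes = torch.tensor([[10., 20., 30., 40.]])
tf_boxes = pt_boxes[:, [1, 0, 3, 2]]   # tensor([[20., 10., 40., 30.]])

# Part 1: pad the dynamic-size result of NMS to a fixed maximum size
keep = torch.tensor([3, 7])            # indices kept by NMS (dynamic length)
max_out = 5                            # static output size chosen for the TFLite model
padded = torch.full((max_out,), -1, dtype=keep.dtype)
padded[:keep.numel()] = keep           # tensor([ 3,  7, -1, -1, -1])
```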

Plus: there's another op, TFLite_Detection_PostProcess, which is a class-aware alternative. https://github.com/tensorflow/tensorflow/blob/master/tensorflow/lite/kernels/detection_postprocess.cc

It also has the benefit of including a u8 implementation. But the problem is that it accepts a different box format (cx, cy, w, h), which could be produced using torchvision.ops.box_convert (which uses torch.stack as the underlying call). Also, its schema is not included in TinyNeuralNetwork yet; it needs some modifications in the following files: https://github.com/alibaba/TinyNeuralNetwork/blob/main/tinynn/converter/operators/tflite/base.py#L15 https://github.com/alibaba/TinyNeuralNetwork/blob/main/tinynn/converter/operators/tflite/custom.py
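
For illustration, converting the box format with torchvision.ops.box_convert (the values are made up):

```python
import torch
from torchvision.ops import box_convert

boxes_xyxy = torch.tensor([[10., 20., 30., 40.]])
boxes_cxcywh = box_convert(boxes_xyxy, in_fmt='xyxy', out_fmt='cxcywh')
# tensor([[20., 30., 20., 20.]]) -> (cx, cy, w, h), the format TFLite_Detection_PostProcess expects
```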

peterjc123 commented 2 years ago

> Hi, the torchvision::ops::batched_nms in PyTorch looks like the tensorflow::ops::combined-non-max-suppression (set q to 1) in TensorFlow, and it's much faster than torchvision.ops.nms. Is it possible to implement this transformation as well?

@zhiqwang To my surprise, torchvision.ops.batched_nms is implemented by doing some transforms before calling torchvision.ops.nms here. So actually, it could also be supported.

zhiqwang commented 2 years ago

> To my surprise, torchvision.ops.batched_nms is implemented by doing some transforms before calling torchvision.ops.nms here. So actually, it could also be supported.

Yep @peterjc123, and the name _batched_nms_coordinate_trick tells us everything about the trick that relates class-agnostic NMS to batched NMS.
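
For the record, a minimal hand-rolled illustration of the coordinate trick (a sketch of the idea, not the torchvision source):

```python
import torch
import torchvision

# Offset each box by a per-class shift larger than any coordinate, so boxes of
# different classes can never overlap; then a single class-agnostic nms call
# behaves like a per-class (batched) NMS.
boxes = torch.tensor([[0., 0., 10., 10.], [1., 1., 11., 11.]])
scores = torch.tensor([0.9, 0.8])
idxs = torch.tensor([0, 1])                   # class id per box

offsets = idxs.to(boxes) * (boxes.max() + 1)  # 0 for class 0, 12 for class 1
keep = torchvision.ops.nms(boxes + offsets[:, None], scores, iou_threshold=0.5)
print(keep)  # tensor([0, 1]) -- both survive: they belong to different classes
```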

zhiqwang commented 2 years ago

FYI @peterjc123, I guess the following figure can explain the secret here:

Copyright of this figure: https://github.com/ultralytics/yolov5/discussions/5825#discussioncomment-1717311 .

peterjc123 commented 2 years ago

@gj-raza The schema of the TFLITE_DETECTION_POSTPROCESS op has been added here in case you need it.

gj-raza commented 2 years ago

> @gj-raza The schema of the TFLITE_DETECTION_POSTPROCESS op has been added here in case you need it.

Thanks @peterjc123. So now I'll have to map this operator to a newly created PyTorch operator in step 6 mentioned above?

peterjc123 commented 2 years ago

@gj-raza Yes, you'll need to do steps 1-6 as I described.

peterjc123 commented 2 years ago

We are not going to implement it on our side because it is rarely used in our scenarios. However, if you have any questions while implementing this feature, feel free to ask. @gj-raza