joonb14 / TFLiteSegmentation

TensorFlow Lite Segmentation example in Python
flask flask-tensorflow image-segmentation image-segmentation-tensorflow image-segmentation-web python python-tflite segment tensorflow2 tflite tflite-model tflite-models tflite-python tflite-segmentation-python web website

TFLite Segmentation Python

This code snippet is heavily based on TensorFlow Lite Segmentation.
The segmentation model can be downloaded from the above link.
For the mask generation, I referred to the Android Segmentation Example.
Follow DeepLabv3.ipynb to see how to use the TFLite model in your Python environment.

Details

The lite-model_deeplabv3_1_metadata_2.tflite model takes a normalized image of shape 257x257x3 as input. The output has shape 257x257x21, where 257x257 indexes the pixel locations of the image and the last dimension of 21 holds one score per label that the TFLite model can classify.
The order of the 21 classes is:

labelsArrays = ["background", "aeroplane", "bicycle", "bird", "boat", "bottle", "bus",  "car", "cat", "chair", "cow", "dining table", "dog", "horse", "motorbike", "person", "potted plant", "sheep", "sofa", "train", "tv"]
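As a small illustration, a class index can be looked up from its name instead of being hard-coded. This is only a sketch that reuses the `labelsArrays` list above; the helper name `label_index` is hypothetical and not part of the original notebook.

```python
labelsArrays = ["background", "aeroplane", "bicycle", "bird", "boat", "bottle", "bus",
                "car", "cat", "chair", "cow", "dining table", "dog", "horse",
                "motorbike", "person", "potted plant", "sheep", "sofa", "train", "tv"]


def label_index(name):
    """Return the output channel index for a class name (hypothetical helper)."""
    return labelsArrays.index(name)


person_index = label_index("person")
print(len(labelsArrays), person_index)  # prints: 21 15
```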

For model inference, we need to load, resize, and normalize the image.
For convenience, I used the Pillow library to load the image and simply divided all values by 255.
Then, if you follow the instructions provided by Google in load_and_run_a_model_in_python, you get an output with a leading batch dimension, i.e. of shape 1x257x257x21.
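The loading and inference steps described above could look roughly like the following. This is a minimal sketch, not the notebook's exact code: the function names `load_and_preprocess` and `run_inference` and the image file name `person.jpg` are assumptions made for illustration, and it assumes the model accepts float32 input scaled by /255 as described in this README. The interpreter calls are the standard `tf.lite.Interpreter` methods.

```python
import numpy as np


def load_and_preprocess(image_path, size=(257, 257)):
    """Load an image with Pillow, resize it, and scale values to [0, 1]."""
    from PIL import Image  # imported here so the rest works without Pillow

    img = Image.open(image_path).convert("RGB").resize(size)
    arr = np.asarray(img, dtype=np.float32) / 255.0
    # Add the batch dimension: (257, 257, 3) -> (1, 257, 257, 3)
    return np.expand_dims(arr, axis=0)


def run_inference(interpreter, input_data):
    """Run one forward pass on a TFLite interpreter and return the first output."""
    interpreter.allocate_tensors()
    input_details = interpreter.get_input_details()
    output_details = interpreter.get_output_details()
    interpreter.set_tensor(input_details[0]["index"], input_data)
    interpreter.invoke()
    return interpreter.get_tensor(output_details[0]["index"])


# Example usage (needs TensorFlow, the downloaded model, and an image file):
# import tensorflow as tf
# interpreter = tf.lite.Interpreter(
#     model_path="lite-model_deeplabv3_1_metadata_2.tflite")
# output_data = run_inference(interpreter, load_and_preprocess("person.jpg"))
# print(output_data.shape)
```

Passing the interpreter in as an argument keeps the model loading separate from the inference call, so the same function can be reused for several images.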

Now we need to turn this output into a mask that segments the class we want.

To do this, we compare the 21 class scores at each pixel and keep the label with the highest score:

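import numpy as np

# output_data is the model output (1x257x257x21); labelsArrays is the list above.
mSegmentBits = np.zeros((257,257)).astype(int)  # best class index per pixel
outputbitmap = np.zeros((257,257)).astype(int)  # 1 for "person", 0 otherwise
for y in range(257):
    for x in range(257):
        maxVal = 0
        mSegmentBits[y][x] = 0

        # Find the class with the highest score at pixel (y, x)
        for c in range(21):
            value = output_data[0][y][x][c]
            if c == 0 or value > maxVal:
                maxVal = value
                mSegmentBits[y][x] = c
        label = labelsArrays[mSegmentBits[y][x]]
#         print(label)
        # 15 is the index of "person" in labelsArrays
        if mSegmentBits[y][x] == 15:
            outputbitmap[y][x] = 1
        else:
            outputbitmap[y][x] = 0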
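The same per-pixel comparison can also be written without explicit loops. The sketch below uses `np.argmax` over the class axis, which should give the same result as the loop (ties go to the lowest class index in both cases). The function name `class_mask` is hypothetical, and the random array only stands in for real model output.

```python
import numpy as np


def class_mask(output_data, class_index=15):
    """Return a binary (H, W) mask: 1 where class_index scores highest, else 0."""
    # argmax over the last axis picks the best-scoring class for every pixel
    segment_bits = np.argmax(output_data[0], axis=-1)
    return (segment_bits == class_index).astype(int)


# Demo on random scores standing in for a real 1x257x257x21 model output
rng = np.random.default_rng(0)
fake_output = rng.random((1, 257, 257, 21)).astype(np.float32)
mask = class_mask(fake_output)
print(mask.shape)  # (257, 257)
```

On a 257x257 output this avoids roughly 1.4 million Python-level iterations, so it is usually much faster than the nested loops.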

In the above example, I wanted to segment only the person class and treat everything else as background. The person label has index 15 in the TFLite model's label list, which is why 15 is used in

if(mSegmentBits[y][x]==15):
    outputbitmap[y][x]=1
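Once a binary mask such as `outputbitmap` exists, one possible next step is to apply it to the resized input image so that only the selected class remains visible. The following is a sketch under that assumption; `apply_mask` is a hypothetical helper, and the arrays here are synthetic stand-ins for a real image and mask.

```python
import numpy as np


def apply_mask(image, mask):
    """Zero out every pixel of an (H, W, 3) image where the (H, W) mask is 0."""
    # mask[..., np.newaxis] has shape (H, W, 1) and broadcasts over the channels
    return image * mask[..., np.newaxis]


# Synthetic stand-ins: an all-ones image and a mask with a 50x50 square of ones
image = np.ones((257, 257, 3), dtype=np.float32)
outputbitmap = np.zeros((257, 257), dtype=int)
outputbitmap[100:150, 100:150] = 1

cutout = apply_mask(image, outputbitmap)
print(cutout.shape)  # (257, 257, 3)
```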

You can adapt the rest of the code to your own needs.
Thank you!