tensorflow / tfjs

A WebGL accelerated JavaScript library for training and deploying ML models.
https://js.tensorflow.org
Apache License 2.0

Google Meet background segmentation model #4177

Closed jameshfisher closed 2 years ago

jameshfisher commented 3 years ago

System information

Describe the feature and the current behavior/state. This Google AI blog post describes the background segmentation model used in Google Meet. This model would be an excellent complement to the models in the tfjs-models collection. (The existing BodyPix model can be (ab)used for background segmentation, but has quality and performance issues for this use-case. I expect the Google Meet model improves on this.)

Will this change the current api? How? No, it would be an addition to tfjs-models.

Who will benefit with this feature? Apps consuming and/or displaying a user-facing camera feed. WebRTC video chat apps are the most obvious, where background blur/replacement is becoming expected. I also expect it could be a useful preprocessing step before applying e.g. PoseNet. It can also be used creatively on images as a pre-processing step -- for example, this recent app to enhance profile pictures integrates a background segmentation solution.

rthadur commented 3 years ago

cc @annxingyuan @tafsiri

simon-lanf commented 3 years ago

this would be useful for us.

tafsiri commented 3 years ago

I'll pass this on to our PM.

jameshfisher commented 3 years ago

Note: I'd also be happy if just the raw model (https://meet.google.com/_/rtcvidproc/release/336842817/segm_lite_v509.tflite) was released under a permissive license - I can figure out the model structure and JavaScript wiring :-)

jasonmayes commented 3 years ago

+1 to this! Would love to see this as part of the model repos for TFJS. A lot of people are making Chrome Extensions to do great things in video calls etc., and this would make those experiences more efficient, e.g. letting them reach higher FPS.

alvaroschipper commented 3 years ago

+1 to this, would be a great, faster alternative to body-pix, really impressed by the performance in Google Meet :)

kirawi commented 3 years ago

Very desirable to have! Though I did just link to this issue from the Jitsi Meet repository, I think it would be very cool to have for other projects that need this functionality but don't have the capabilities to develop an in-house model.

jameshfisher commented 3 years ago

The blog post about this model links to this Model Card describing the model, which reads

LICENSED UNDER Apache License, Version 2.0

The Model Card also links to this paper describing Model Cards in general, which says that Model Cards can describe a license that the model is released under. So I believe the above license applies to the described model itself (e.g. rather than to the Model Card document).

So it seems like the raw .tflite model here is already Apache-licensed! @jasonmayes would you agree with this / is this Google's position?

(Thanks to @blaueente for originally noting this license in the Model Card!)

stanhrivnak commented 3 years ago

Quoting @jameshfisher: "Note: I'd also be happy if just the raw model (https://meet.google.com/_/rtcvidproc/release/336842817/segm_lite_v509.tflite) was released under a permissive license - I can figure out the model structure and JavaScript wiring :-)"

@jameshfisher I have successfully deployed the raw tflite model (BTW, many thanks for the link!) in a desktop app using MediaPipe. But I failed to do so for a web app, since MediaPipe doesn't have any documentation for that yet (just some JS APIs for specific examples, but not for custom models). But it looks like you're saying that you did it. How? Have you extracted the layers and weights of the model, "manually" recreated the same TF model, and then converted it to TFJS? Or have you managed to compile the tflite model to wasm and use MediaPipe? Many thanks!

kirawi commented 3 years ago

@stanhrivnak I found this while looking into it myself: https://gist.github.com/tworuler/bd7bd4c6cd9a8fbbeb060e7b64cfa008 Unfortunately, I'm not familiar with TensorFlow (sad AMD GPU gang), so I have no idea how it works or how to modify it. PINTO0309 uses modified versions of that script for his tflite -> pb conversion scripts.

PINTO0309 commented 3 years ago

I have generated and committed .pb, .tflite (float32/float16), INT8, EdgeTPU, TFJS, TF-TRT, CoreML, and OpenVINO IR models for testing. However, I was so exhausted that I did not write a program to test them. I would be very happy if you could help test them. 😃 https://github.com/PINTO0309/PINTO_model_zoo/tree/master/082_MediaPipe_Meet_Segmentation

If there are any licensing issues, I will delete them.

kirawi commented 3 years ago

Amazing work!

PINTO0309 commented 3 years ago

A Japanese engineer has implemented it in TFJS. There still seems to be a small problem with the conversion: the mask is shifted to the left. Also, there is no smoothing post-processing such as "light wrapping", so the border is jagged.

https://user-images.githubusercontent.com/33194443/103107920-4d37da00-4686-11eb-9129-8e34e272e638.mp4
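The jagged border comes from turning the model output into a hard 0/255 mask. A minimal NumPy sketch of the simplest alternative, assuming the foreground probability has already been upsampled to the frame size, uses that probability as a soft alpha when compositing. This is plain alpha compositing, not the light-wrapping step from Google's blog post, and `composite` is a hypothetical helper name.

```python
import numpy as np


def composite(frame: np.ndarray, background: np.ndarray, fg_prob: np.ndarray) -> np.ndarray:
    """Blend frame over background using a soft foreground probability as alpha.

    frame, background: HxWx3 uint8 images of the same size.
    fg_prob: HxW float array in [0, 1], already resized to the frame size.
    """
    # Add a trailing axis so the HxW alpha broadcasts over the 3 color channels
    alpha = np.clip(fg_prob, 0.0, 1.0)[..., np.newaxis]
    blended = alpha * frame.astype(np.float32) + (1.0 - alpha) * background.astype(np.float32)
    return np.round(blended).astype(np.uint8)
```

Pixels near the person's outline get intermediate alpha values, so the edge fades instead of stepping.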

kirawi commented 3 years ago

Is the shifting fixable?

PINTO0309 commented 3 years ago

I'm using my own tricks in the optimization phase, so that may be affecting the results. Please give me some time so I can try this out.

PINTO0309 commented 3 years ago

Quoting @kirawi: "Is the shifting fixable?"

It worked. However, the 128x128 model resolution does not seem to be very accurate.

[Images: test (copy 1), out1]

kirawi commented 3 years ago

That's unfortunate, but nonetheless amazing work man!

kirawi commented 3 years ago

Ah wait, I think that is intentional to reduce the computational requirements of the model. The bilateral filter mentioned in the blog further refines the mask, and it might be the case that the model works best with bright colours. I think all things considered, the model does its job fairly well. By the way, mind sharing the test setup you have for the model?

PINTO0309 commented 3 years ago

@kirawi I did not use a bilateral filter and just binarized the output, so the result may not be good.

```shell
### Download test.jpg
$ sudo gdown --id 1Tyv6P2zshOCqTgYBLoa0aC3Co8W-9JPG

### Download segm_lite_v509_128x128_float32.tflite
$ sudo gdown --id 1qOlcK8iKki_aAi_OrxE2YLaw5EZvQn1S
```

```python
import numpy as np
from PIL import Image
try:
    from tflite_runtime.interpreter import Interpreter
except ImportError:
    from tensorflow.lite.python.interpreter import Interpreter

# Load the test image and keep its original size for upscaling the masks later
img = Image.open('test.jpg').convert('RGB')
w, h = img.size
# Resize to the 128x128 model input, scale to [0, 1] and add a batch dimension
img = img.resize((128, 128))
img = np.asarray(img, dtype=np.float32) / 255.
img = img[np.newaxis, :, :, :]

# TensorFlow Lite
interpreter = Interpreter(model_path='segm_lite_v509_128x128_float32.tflite', num_threads=4)
interpreter.allocate_tensors()
input_details = interpreter.get_input_details()[0]['index']
output_details = interpreter.get_output_details()[0]['index']

interpreter.set_tensor(input_details, img)
interpreter.invoke()
output = interpreter.get_tensor(output_details)

# The output has two channels; binarize each one at 0.5
print(output.shape)
out1 = output[0][:, :, 0]
out2 = output[0][:, :, 1]

out1 = (out1 > 0.5) * 255
out2 = (out2 > 0.5) * 255

print('out1:', out1.shape)
print('out2:', out2.shape)

# Scale the masks back up to the original image size and save them
out1 = Image.fromarray(np.uint8(out1)).resize((w, h))
out2 = Image.fromarray(np.uint8(out2)).resize((w, h))

out1.save('out1.jpg')
out2.save('out2.jpg')
```

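The Python scripts in this thread threshold each output channel at 0.5 directly, while the Node.js test code applies a softmax first. Whether the converted model already ends in a softmax is not established here; if it outputs raw logits (an assumption), the threshold only means "more likely than not" after normalizing across the two channels. A minimal sketch, with hypothetical helper names:

```python
import numpy as np


def channel_softmax(logits: np.ndarray) -> np.ndarray:
    """Softmax over the last (channel) axis, shifted by the max for numerical stability."""
    shifted = logits - logits.max(axis=-1, keepdims=True)
    exp = np.exp(shifted)
    return exp / exp.sum(axis=-1, keepdims=True)


def masks_from_logits(logits: np.ndarray, threshold: float = 0.5):
    """Return (channel-0 mask, channel-1 mask) as uint8 0/255 arrays.

    logits: HxWx2 array, e.g. output[0] from the interpreter above.
    """
    probs = channel_softmax(logits)
    mask0 = ((probs[..., 0] > threshold) * 255).astype(np.uint8)
    mask1 = ((probs[..., 1] > threshold) * 255).astype(np.uint8)
    return mask0, mask1
```

With only two channels, the softmax probability of channel 1 equals sigmoid(l1 - l0), so thresholding it at 0.5 is the same as asking whether l1 > l0.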
w-okada commented 3 years ago

I created a demo page that uses PINTO's model converted to TensorFlow.js.

https://flect-lab-web.s3-us-west-2.amazonaws.com/P01_wokers/t11_googlemeet-segmentation/index.html

You can change the input device with the control panel on the right side, so you can also try it with your own camera.

By default, this page uses the new version of PINTO's model, but the mask still seems to be shifted slightly to the left...

You can also switch to the old version of PINTO's model with the control panel on the right side: select modelPath and click the reload model button.

PINTO0309 commented 3 years ago

I overlaid the image with the tflite implementation at hand. Does it shift when I apply the filter?

https://user-images.githubusercontent.com/33194443/103143865-68781780-4762-11eb-864d-3594c61e8972.mp4

kirawi commented 3 years ago

I don't think it's shifting, it looks more like the one with the white background is capturing more of the background than the other one.

PINTO0309 commented 3 years ago

@kirawi I am currently investigating this issue together with @w-okada on Twitter.

w-okada commented 3 years ago

Hmm, I spent a lot of time yesterday trying to solve the "shifting" problem, but I couldn't. Can anybody help me? Here is my simple test code for Node.js.

```javascript
const tf = require('@tensorflow/tfjs-node');
const fs = require('fs');
const jpeg = require('jpeg-js');
const { createCanvas } = require('canvas');

// Decode a JPEG into { width, height, data }, where data holds RGBA bytes row by row
const readImage = path => {
    const buf = fs.readFileSync(path);
    return jpeg.decode(buf, { useTArray: true });
};

// Copy the first numChannels channels of every RGBA pixel into a flat array
const imageByteArray = (image, numChannels) => {
    const pixels = image.data;
    const numPixels = image.width * image.height;
    const values = new Int32Array(numPixels * numChannels);

    for (let i = 0; i < numPixels; i++) {
        for (let channel = 0; channel < numChannels; ++channel) {
            values[i * numChannels + channel] = pixels[i * 4 + channel];
        }
    }
    return values;
};

const main = async () => {
    const image = readImage('test.jpg');
    const handler = tf.io.fileSystem('./model/model.json');
    const model = await tf.loadGraphModel(handler);
    const numChannels = 3;
    const values = imageByteArray(image, numChannels);
    // Pixels are stored row by row, so the shape must be [height, width, channels]
    const inShape = [image.height, image.width, numChannels];
    let input = tf.tensor3d(values, inShape, 'float32');

    input = tf.image.resizeBilinear(input, [128, 128]);
    input = input.expandDims(0);
    // Scale to [0, 1] by dividing by 255, as the Python scripts in this thread do
    input = input.div(255);

    let predict = model.predict(input);
    predict = predict.softmax();
    const res = await predict.array();
    const bm = res[0]; // [128][128][2]
    const width = bm[0].length;
    const height = bm.length;
    const canvas = createCanvas(width, height);
    const ctx = canvas.getContext('2d');
    const imageData = ctx.getImageData(0, 0, width, height);
    for (let rowIndex = 0; rowIndex < height; rowIndex++) {
        for (let colIndex = 0; colIndex < width; colIndex++) {
            const pixOffset = (rowIndex * width + colIndex) * 4;
            // Semi-transparent red where channel 0 > 0.5, semi-transparent black elsewhere
            const isChannel0 = bm[rowIndex][colIndex][0] > 0.5;
            imageData.data[pixOffset + 0] = isChannel0 ? 255 : 0;
            imageData.data[pixOffset + 1] = 0;
            imageData.data[pixOffset + 2] = 0;
            imageData.data[pixOffset + 3] = 128;
        }
    }
    ctx.putImageData(imageData, 0, 0);

    // Scale the 128x128 mask back up to the original image size and save it
    const tmpCanvas = createCanvas(image.width, image.height);
    tmpCanvas.getContext('2d').drawImage(canvas, 0, 0, tmpCanvas.width, tmpCanvas.height);
    fs.writeFileSync('./res.png', tmpCanvas.toBuffer('image/png'));
};

main();
```

[Images: test, res]
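One detail worth checking in any pipeline like this: jpeg-js returns pixels row by row, so a flat RGBA buffer has the layout [height, width, 4], and a tensor built from it must list height first. Listing [width, height] instead leaves the total size unchanged, so nothing errors, but every row of a non-square image is filled with the wrong pixels. A minimal NumPy sketch of the correct reshape; `rgba_to_hwc` is a hypothetical helper:

```python
import numpy as np


def rgba_to_hwc(rgba: bytes, width: int, height: int) -> np.ndarray:
    """Turn a row-major RGBA byte buffer into a (height, width, 3) array."""
    pixels = np.frombuffer(rgba, dtype=np.uint8).reshape(height, width, 4)
    return pixels[..., :3]  # drop the alpha channel


# Build a 2x3 (height x width) image whose R and G bytes encode row and column
h, w = 2, 3
buf = bytearray()
for r in range(h):
    for c in range(w):
        buf += bytes([r, c, 7, 255])
img = rgba_to_hwc(bytes(buf), w, h)
```

Here `img[1, 2]` holds `[1, 2, 7]`, i.e. row 1, column 2, whereas reshaping the same buffer as (width, height, 4) puts the pixel from row 0, column 2 at index `[1, 0]`.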

stanhrivnak commented 3 years ago

Hi guys, first of all, many thanks to @PINTO0309, @w-okada, and others for putting your effort into this! Great work so far! I would really love to have this great model from Google in my web app (currently I use BodyPix with custom improvements, but it still sucks). Here are my 2 cents. I have deployed the discussed original tflite model (https://meet.google.com/_/rtcvidproc/release/336842817/segm_lite_v509.tflite) in a desktop app using MediaPipe, and it performs amazingly (see the attached video), even in suboptimal lighting. What you see is the raw model performance without any post-processing (with it, it looks even better), at 128 x 128 resolution. https://user-images.githubusercontent.com/64148065/103182841-d2053c80-48ae-11eb-8ba1-1a1518c9defb.mov

The implications are:

  1. There is hope: the model is already good enough, and the 128 x 128 resolution is high enough to give nice results when upsampling to SD/HD. It's also super fast, with inference running well above 25 FPS.
  2. There has to be a flaw in the manual conversion to h5/TFJS.

I think the best approach would be to compare the outputs of the original tflite model and the created TFJS model (or h5/tflite) layer by layer, to see where they deviate, and focus on fixing that part. The problem is that the original tflite model uses some custom ops, so it can't be read in Python directly. But we know the definitions of these ops; here they are (not sure if it uses all 3, but it uses at least "Convolution2DTransposeBias", because that is the error Python gives me): https://github.com/google/mediapipe/tree/master/mediapipe/util/tflite/operations The problem is that they are written in C++, so they have to be rewritten in Python, or we need to go with TensorFlow C++. Also, as stated here: https://github.com/google/mediapipe/issues/35#issuecomment-630022641 these custom ops are just merged existing operations, so it should be straightforward.

So this is my plan. I can work on it only ~ 2 hours a day, so if you're faster, go for it and let me know! :) Or if you have any other ideas, share it please!
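Getting the intermediate outputs out of each model is the hard part of this plan; once they are available, finding where the conversion goes wrong is mechanical. A minimal sketch, assuming both models' per-layer outputs have already been dumped into dicts keyed by a shared layer name (`first_divergence` and the layer names are hypothetical):

```python
import numpy as np


def first_divergence(reference: dict, candidate: dict, layer_order: list, atol: float = 1e-3):
    """Return (layer_name, max_abs_diff) for the first layer whose outputs differ.

    reference/candidate map layer names to arrays (e.g. outputs dumped from the
    original tflite model and from the converted model); layer_order lists the
    names in execution order. Returns None if every layer matches within atol.
    """
    for name in layer_order:
        ref = np.asarray(reference[name], dtype=np.float64)
        cand = np.asarray(candidate[name], dtype=np.float64)
        if ref.shape != cand.shape:
            # A shape mismatch (e.g. swapped height/width) counts as a divergence
            return name, float("inf")
        diff = float(np.max(np.abs(ref - cand)))
        if diff > atol:
            return name, diff
    return None
```

Stopping at the first mismatch matters because every later layer inherits the error, so only the first divergent layer points at the faulty conversion step.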

PINTO0309 commented 3 years ago

@stanhrivnak I have already succeeded in replacing custom operations. You're right, it would be quicker to check the results of the output for each layer, but I don't have enough time to do that since I'm also working on converting other models at the same time.

https://github.com/PINTO0309/PINTO_model_zoo/blob/32f1a821bc3c8a04a53ba3e18a45921a136de889/082_MediaPipe_Meet_Segmentation/01_segm_lite_tflite2h5_weight_int_fullint_float16_quant.py#L691-L704
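As noted above, MediaPipe's custom ops are merged versions of existing operations; Convolution2DTransposeBias is, as its name says, a transposed convolution followed by a bias add. A minimal single-channel NumPy sketch of that combination, assuming no padding or cropping (the real op's padding and multi-channel handling are not modeled, and `conv2d_transpose_bias` is a hypothetical name):

```python
import numpy as np


def conv2d_transpose_bias(x: np.ndarray, w: np.ndarray, bias: float, stride: int) -> np.ndarray:
    """Single-channel 2D transposed convolution followed by a bias add.

    x: HxW input, w: KxK kernel. Without padding or cropping the output is
    ((H - 1) * stride + K) x ((W - 1) * stride + K).
    """
    h, wd = x.shape
    k = w.shape[0]
    out = np.zeros(((h - 1) * stride + k, (wd - 1) * stride + k), dtype=np.float64)
    for i in range(h):
        for j in range(wd):
            # Each input pixel "stamps" a scaled copy of the kernel onto the output
            out[i * stride:i * stride + k, j * stride:j * stride + k] += x[i, j] * w
    return out + bias
```

With a 2x2 kernel of ones and stride 2, each input value becomes a 2x2 block, which is the kind of learned upsampling a segmentation decoder uses to get back to mask resolution.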

stanhrivnak commented 3 years ago

@PINTO0309 Unfortunately, the tflite format doesn't allow accessing intermediate results after each operation/layer, just the final output node... so we can't debug your code this way... @jasonmayes could you kindly let us know when we can expect the release of the TFJS version of the model? Will it be weeks, months, or "definitely not soon"? This information would greatly help us in our planning. Many thanks in advance!

simon-lanf commented 3 years ago

@w-okada

https://flect-lab-web.s3-us-west-2.amazonaws.com/P01_wokers/t11_googlemeet-segmentation/index.html

Could you publish the code for this page, please? Thank you.

marcelgoya commented 3 years ago

@simon-lanf You should be able to get it by simply opening the referenced JS/TSX files. Chrome DevTools is your friend here...

floe commented 3 years ago

@w-okada this is entirely off-topic, but I just have to ask - was the picture in your post taken in Z10, by any chance?

w-okada commented 3 years ago

@floe I don't know. I just used the picture PINTO provided in the post above:

```shell
$ sudo gdown --id 1Tyv6P2zshOCqTgYBLoa0aC3Co8W-9JPG
```

w-okada commented 3 years ago

@simon-lanf

This code is in my dev branch. You can view it at (or clone it from) https://github.com/w-okada/image-analyze-workers/tree/dev/011demo_googlemeet-segmentation-worker-js-demo

floe commented 3 years ago

Oh, now I see, the image is from PASCAL VOC. Sorry for the noise.

floe commented 3 years ago

JFYI, I have a C++ TFLite implementation using the Google Meet model for background segmentation: https://github.com/floe/deepbacksub

PINTO0309 commented 3 years ago

Since someone pointed me to a full-size model, I will try to quantize it, including converting its custom operations.

144x256 https://meet.google.com/_/rtcvidproc/release_1wttl/345264209/segm_full_v679.tflite

simon-lanf commented 3 years ago

Can anyone tell whether this one is different from v679?

https://meet.google.com/_/rtcvidproc/release_1wttl/345264209/segm_lite_v681.tflite

floe commented 3 years ago

@simon-lanf AFAICT it's the same model, just the resolution is different.

kirawi commented 3 years ago

That one is 96x160, I think

jiangjianping commented 3 years ago

@tafsiri

Is there any information about the joint bilateral filter used in Google Meet? Which image is used as the guide image? Thanks.

PINTO0309 commented 3 years ago

I replaced the custom ops of the full-size model with standard ops, and then converted it further with my own optimizations. I have not implemented any post-processing, but I think it performs quite well. No bilateral filter is used.

I have also converted as much as possible for the various frameworks. If you run a TFJS model and experience misalignment, it is a problem with the TFJS runtime.

[Screenshot 2021-01-05 16:02:46]

```shell
### Download test.jpg
$ sudo gdown --id 1Tyv6P2zshOCqTgYBLoa0aC3Co8W-9JPG

### Download segm_full_v679_144x256_opt_float32.tflite
$ sudo gdown --id 1tKhwGLJ3f0GYDAWFiufv0e7DGVfW6ztS
```

```python
import numpy as np
from PIL import Image
try:
    from tflite_runtime.interpreter import Interpreter
except ImportError:
    from tensorflow.lite.python.interpreter import Interpreter

# Load the test image and keep its original size for upscaling the masks later
img = Image.open('test.jpg').convert('RGB')
w, h = img.size
# PIL's resize takes (width, height): this produces the 144x256 (HxW) model input
img = img.resize((256, 144))
img = np.asarray(img, dtype=np.float32) / 255.
img = img[np.newaxis, :, :, :]

# TensorFlow Lite
interpreter = Interpreter(model_path='segm_full_v679_144x256_opt_float32.tflite', num_threads=4)
interpreter.allocate_tensors()
input_details = interpreter.get_input_details()[0]['index']
output_details = interpreter.get_output_details()[0]['index']

interpreter.set_tensor(input_details, img)
interpreter.invoke()
output = interpreter.get_tensor(output_details)

# The output has two channels; binarize each one at 0.5
print(output.shape)
out1 = output[0][:, :, 0]
out2 = output[0][:, :, 1]

out1 = (out1 > 0.5) * 255
out2 = (out2 > 0.5) * 255

print('out1:', out1.shape)
print('out2:', out2.shape)

# Scale the masks back up to the original image size and save them
out1 = Image.fromarray(np.uint8(out1)).resize((w, h))
out2 = Image.fromarray(np.uint8(out2)).resize((w, h))

out1.save('out1.jpg')
out2.save('out2.jpg')
```

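The models discussed in this thread use different input sizes (128x128 for the lite model, 144x256 for the full model, and possibly 96x160 for v681), and PIL's `Image.resize` takes (width, height) while model shapes are written height first. A minimal sketch of a preprocessing helper that takes the model's (height, width), assuming the same [0, 1] scaling as the scripts above; `preprocess` is a hypothetical name, and in practice the size could be read from `interpreter.get_input_details()[0]['shape'][1:3]`:

```python
import numpy as np
from PIL import Image


def preprocess(img: Image.Image, model_hw: tuple) -> np.ndarray:
    """Resize to the model's (height, width); return a 1xHxWx3 float32 batch in [0, 1]."""
    height, width = model_hw
    resized = img.convert("RGB").resize((width, height))  # PIL expects (width, height)
    arr = np.asarray(resized, dtype=np.float32) / 255.0
    return arr[np.newaxis, ...]
```

Keeping the (height, width) convention in one place avoids silently feeding a transposed 256x144 image to the 144x256 model.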
PINTO0309 commented 3 years ago

I re-committed, revising the conversion method and also improving the accuracy of the 128x128 Lite model.

[Screenshot 2021-01-05 17:17:30]

floe commented 3 years ago

@PINTO0309 excellent, thank you. Can you briefly summarize what optimizations you used?

w-okada commented 3 years ago

Wow!!! Great. With tfjs, it completely worked!

Demo page is here. You can try it! https://flect-lab-web.s3-us-west-2.amazonaws.com/P01_wokers/t11_googlemeet-segmentation/index.html

https://user-images.githubusercontent.com/48346627/103629625-ea96e600-4f83-11eb-832b-c2d69ed8c228.mp4

amiregelz commented 3 years ago

@w-okada This is amazing!

w-okada commented 3 years ago

With the WASM backend, I get the image below. Hmm.

[Image]

PINTO0309 commented 3 years ago

@floe

I used the following tricks:

  1. Fused bias, weight, and activation functions (ReLU/ReLU6) into Convolution, FullyConnected, and DepthwiseConvolution.
  2. Since the tflite model published by Google is quantized to Float16, I deliberately converted it to Float32 temporarily so that it could be converted to the various frameworks.
  3. To quantize it to INT8 and run it on EdgeTPU, a fast inference device, I made my own modifications to Hard-Swish:
     ```python
     ### For TFJS, TFLite, TF-TRT, OpenVINO
     hswish = x * tf.nn.relu6(x + 3) * 0.16666667
     ### For EdgeTPU
     hswish = x * tf.nn.relu6(x + 3) * 0.16666666
     ```
  4. Because of the problems with TensorFlow's ResizeBilinear, I used a little trick of my own.

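Both Hard-Swish variants in item 3 approximate the exact definition h-swish(x) = x * ReLU6(x + 3) / 6 with a rounded constant in place of 1/6. A minimal NumPy sketch comparing them over a typical activation range (the thread does not explain why EdgeTPU needed the second constant, so this only shows that both stay numerically close to the exact form):

```python
import numpy as np


def relu6(x):
    """Clamp to the range [0, 6]."""
    return np.minimum(np.maximum(x, 0.0), 6.0)


def hard_swish(x, scale=1.0 / 6.0):
    """h-swish(x) = x * ReLU6(x + 3) * scale; scale = 1/6 gives the exact form."""
    return x * relu6(x + 3.0) * scale


x = np.linspace(-5.0, 5.0, 11)
exact = hard_swish(x)
tfjs_style = hard_swish(x, 0.16666667)
edgetpu_style = hard_swish(x, 0.16666666)
```

Hard-swish is 0 for x <= -3 and equals x for x >= 3, and over [-5, 5] both rounded constants differ from the exact result by well under 1e-6.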
jiangjianping commented 3 years ago

@w-okada Excellent and beautiful! Which post-processing do you use?

simon-lanf commented 3 years ago

@w-okada

Yeah, I can reproduce it too. I can confirm that with WASM the results are different for the same images.

kirawi commented 3 years ago

Quick hacky joint bilateral filter. I know nothing about this, but it seems to work. Interestingly, out1 seems to be more accurate than out2.

```python
import numpy as np
import cv2  # cv2.ximgproc requires the opencv-contrib-python package
try:
    from tflite_runtime.interpreter import Interpreter
except ImportError:
    from tensorflow.lite.python.interpreter import Interpreter

img = cv2.imread('Capture.png')
h, w = img.shape[:2]

# OpenCV loads BGR; convert to RGB, assuming the model expects RGB like the PIL-based scripts above
img = cv2.cvtColor(img, cv2.COLOR_BGR2RGB)
img = cv2.resize(img, (256, 144))  # cv2.resize takes (width, height)
img = img.astype(np.float32) / 255.
img = img[np.newaxis, :, :, :]

# TensorFlow Lite
interpreter = Interpreter(model_path='model_float16_quant.tflite', num_threads=4)
interpreter.allocate_tensors()
input_details = interpreter.get_input_details()[0]['index']
output_details = interpreter.get_output_details()[0]['index']

interpreter.set_tensor(input_details, img)
interpreter.invoke()
output = interpreter.get_tensor(output_details)

print(output.shape)
out1 = output[0][:, :, 0]
out2 = output[0][:, :, 1]

# Inverted binary masks: 255 where the channel is <= 0.5, 0 elsewhere
out1 = np.uint8((out1 <= 0.5) * 255)
out2 = np.uint8((out2 <= 0.5) * 255)

print('out1:', out1.shape)
print('out2:', out2.shape)

out1 = cv2.resize(out1, (w, h))
out2 = cv2.resize(out2, (w, h))

cv2.imwrite('out1.jpg', out1)
cv2.imwrite('out2.jpg', out2)

# jointBilateralFilter(joint, src, d, sigmaColor, sigmaSpace): out2 guides, out1 is filtered
out3 = cv2.ximgproc.jointBilateralFilter(out2, out1, 8, 75, 75)

cv2.imwrite('out3.jpg', out3)
```

[Images: Capture, out5]
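In OpenCV's `cv2.ximgproc.jointBilateralFilter(joint, src, d, sigmaColor, sigmaSpace)`, the first argument is the guide ("joint") image and the second is the image being smoothed, so kirawi's call smooths out1 using edges found in out2. A naive, slow NumPy sketch of the idea, meant only to show what the two roles do; it does not reproduce OpenCV's exact kernel or parameters, and the function and parameter names are hypothetical:

```python
import numpy as np


def joint_bilateral_filter(src, guide, radius, sigma_space, sigma_color):
    """Naive joint (cross) bilateral filter for single-channel images.

    Each output pixel is a weighted average of `src` over a (2r+1)^2 window.
    Weights combine spatial distance with the intensity difference measured
    in `guide`, so edges present in the guide are preserved in the result.
    """
    src = np.asarray(src, dtype=np.float64)
    guide = np.asarray(guide, dtype=np.float64)
    h, w = src.shape
    out = np.zeros_like(src)
    offsets = range(-radius, radius + 1)
    for y in range(h):
        for x in range(w):
            acc = 0.0
            norm = 0.0
            for dy in offsets:
                for dx in offsets:
                    yy, xx = y + dy, x + dx
                    if 0 <= yy < h and 0 <= xx < w:
                        spatial = np.exp(-(dy * dy + dx * dx) / (2 * sigma_space ** 2))
                        diff = guide[yy, xx] - guide[y, x]
                        color = np.exp(-(diff ** 2) / (2 * sigma_color ** 2))
                        acc += spatial * color * src[yy, xx]
                        norm += spatial * color
            # norm > 0 because the center pixel always has weight 1
            out[y, x] = acc / norm
    return out
```

With a sharp edge in the guide, the filtered mask keeps that edge; with a flat guide the filter reduces to plain Gaussian smoothing and blurs the mask edge. This is why the choice of guide image matters.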

jiangjianping commented 3 years ago

@kirawi Interesting. Why do you use out2 as the guide image?