transcranial / keras-js

Run Keras models in the browser, with GPU support using WebGL
https://transcranial.github.io/keras-js
MIT License

Is there a way to run the batch normalization layer in training mode (instance normalization)? #114

Open MingwangLin opened 6 years ago

MingwangLin commented 6 years ago

Hi, I found that the batch normalization layer in keras.js can only be run in inference mode (using the moving mean and moving variance). Is there a way to run it in training mode (using batch statistics)? Thanks!
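For reference, both modes apply the same per-channel formula and differ only in where the mean and variance come from. It is written here for an input indexed by batch n, spatial position (h, w) and channel c, with γ, β and ε being the layer's scale, offset and epsilon:

$$
y_{n,h,w,c} = \gamma_c \, \frac{x_{n,h,w,c} - \mu_c}{\sqrt{\sigma_c^2 + \epsilon}} + \beta_c
$$

In inference mode, μ_c and σ_c² are the stored `moving_mean` and `moving_variance`. In training mode they are computed from the input itself, over the batch and spatial dimensions:

$$
\mu_c = \frac{1}{NHW} \sum_{n,h,w} x_{n,h,w,c}, \qquad
\sigma_c^2 = \frac{1}{NHW} \sum_{n,h,w} \left( x_{n,h,w,c} - \mu_c \right)^2
$$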

MingwangLin commented 6 years ago

Let me try to answer this myself.

After several days of reading the source code, I found that there is actually a way to run the batch normalization layer in training mode, or in other words, to implement an instance normalization layer.
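Why the two coincide: in training mode, batch normalization normalizes each channel c with statistics taken over the batch and spatial dimensions. Assuming the model is run on a single image at a time (N = 1), the batch sum disappears and what remains are per-image, per-channel statistics over the H × W positions, which is exactly what instance normalization uses:

$$
\mu_c = \frac{1}{HW} \sum_{h,w} x_{h,w,c}, \qquad
\sigma_c^2 = \frac{1}{HW} \sum_{h,w} \left( x_{h,w,c} - \mu_c \right)^2, \qquad
y_{h,w,c} = \gamma_c \, \frac{x_{h,w,c} - \mu_c}{\sqrt{\sigma_c^2 + \epsilon}} + \beta_c
$$

For batches of more than one image the two differ, because batch normalization would pool the statistics across images.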

This requires computing the mean and variance of the input on the fly and replacing the original moving mean and variance data with the computed values. The code below is my implementation; it needs to be added before this line.

```javascript
// Requires `ndarray` and `ndarray-ops` (imported as `ops`) in the layer file.

// Transfer data from the WebGL texture on the GPU to an ndarray on the CPU
x.transferFromGLTexture()

// Compute the per-channel mean and variance of the input.
// This code treats x.tensor as 2D: [spatial positions, channels]
const channels = x.tensor.shape[1]
const channelDataSize = x.tensor.shape[0]
const channelDataRaveled = ndarray(new x.arrayType(channelDataSize), [channelDataSize])
const xMean = []
const xVariance = []
for (let i = 0; i < channels; i++) {
  // Copy channel i into a contiguous 1D buffer
  ops.assign(channelDataRaveled, x.tensor.pick(null, i))

  const mean = ops.sum(channelDataRaveled) / channelDataSize
  xMean.push(mean)

  // Population variance: mean of squared deviations (modifies the buffer in place)
  ops.subseq(channelDataRaveled, mean)
  ops.powseq(channelDataRaveled, 2)
  const variance = ops.sum(channelDataRaveled) / channelDataSize
  xVariance.push(variance)
}

// Replace the stored moving mean and variance with the computed statistics
this.weights['moving_mean'].replaceTensorData(xMean)
this.weights['moving_variance'].replaceTensorData(xVariance)
```
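To try the idea outside keras-js, here is a minimal, dependency-free sketch of the same procedure on a plain `Float32Array`. It assumes the input is stored row-major as `[positions, channels]` (positions = height × width), matching how the snippet above indexes `x.tensor`. The function names are hypothetical and not part of the keras-js API. It computes per-channel statistics, substitutes them for the moving statistics, and then applies the ordinary inference-mode formula, which gives instance normalization for a single image.

```javascript
// Hypothetical helpers for illustration; not part of keras-js.
// Layout assumption: flat row-major [rows, channels], rows = height * width.

// Per-channel mean and population variance over all rows (two passes,
// like the ndarray-ops snippet: sum first, then squared deviations).
function channelStats(data, rows, channels) {
  const mean = new Float64Array(channels)
  const variance = new Float64Array(channels)
  for (let r = 0; r < rows; r++) {
    for (let c = 0; c < channels; c++) mean[c] += data[r * channels + c]
  }
  for (let c = 0; c < channels; c++) mean[c] /= rows
  for (let r = 0; r < rows; r++) {
    for (let c = 0; c < channels; c++) {
      const d = data[r * channels + c] - mean[c]
      variance[c] += d * d
    }
  }
  for (let c = 0; c < channels; c++) variance[c] /= rows
  return { mean, variance }
}

// Inference-mode batch normalization:
// y = gamma * (x - mean) / sqrt(variance + epsilon) + beta, per channel.
function batchNormInference(data, rows, channels, { gamma, beta, mean, variance, epsilon }) {
  const out = new Float32Array(data.length)
  for (let r = 0; r < rows; r++) {
    for (let c = 0; c < channels; c++) {
      const i = r * channels + c
      out[i] = (gamma[c] * (data[i] - mean[c])) / Math.sqrt(variance[c] + epsilon) + beta[c]
    }
  }
  return out
}

// Instance normalization: the inference formula with the moving statistics
// replaced by statistics computed from the input itself.
// epsilon defaults to 1e-3, the Keras BatchNormalization default.
function instanceNorm(data, rows, channels, gamma, beta, epsilon = 1e-3) {
  const { mean, variance } = channelStats(data, rows, channels)
  return batchNormInference(data, rows, channels, { gamma, beta, mean, variance, epsilon })
}
```

For example, `instanceNorm(Float32Array.from([1, 10, 2, 20, 3, 30, 4, 40]), 4, 2, [1, 1], [0, 0])` normalizes two channels over four positions, after which each output channel has a mean of approximately zero.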

However, this implementation has to transfer data from the WebGL texture on the GPU to an ndarray on the CPU and then process it on the CPU with the code above. This slows things down considerably: in my image transfer task, generating one image now takes 7 to 8 seconds, compared with 1 to 2 seconds before.

It could probably be made faster by reimplementing the code above in WebGL, but I'm not familiar with WebGL.
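One way a WebGL version could avoid the GPU-to-CPU readback is to compute the per-channel sums on the GPU with a multi-pass reduction: each pass renders to a texture with half as many rows as the previous one, and each output texel adds two input texels. The sketch below only simulates that pass structure on the CPU in plain JavaScript to show the idea; it is not WebGL code, and `reduceRowsPairwise` and `channelStatsByReduction` are hypothetical names.

```javascript
// CPU simulation of a GPU-style pairwise (tree) reduction over the rows of a
// row-major [rows, channels] array. Each loop iteration stands in for one
// render pass whose output has ceil(rows / 2) rows; every output row adds two
// input rows. After ceil(log2(rows)) passes one row of per-channel sums remains.
function reduceRowsPairwise(data, rows, channels) {
  let current = Float32Array.from(data)
  let currentRows = rows
  while (currentRows > 1) {
    const nextRows = Math.ceil(currentRows / 2)
    const next = new Float32Array(nextRows * channels)
    for (let r = 0; r < nextRows; r++) {
      for (let c = 0; c < channels; c++) {
        const a = current[2 * r * channels + c]
        // With an odd row count, the last output row has no partner row.
        const b = 2 * r + 1 < currentRows ? current[(2 * r + 1) * channels + c] : 0
        next[r * channels + c] = a + b
      }
    }
    current = next
    currentRows = nextRows
  }
  return current // length === channels: per-channel sums
}

// Mean and population variance from two reductions, mirroring the two-step
// computation in the snippet above (sum, then sum of squared deviations).
function channelStatsByReduction(data, rows, channels) {
  const mean = reduceRowsPairwise(data, rows, channels).map((s) => s / rows)
  // On the GPU this element-wise step would be its own shader pass.
  const sq = data.map((v, i) => (v - mean[i % channels]) ** 2)
  const variance = reduceRowsPairwise(sq, rows, channels).map((s) => s / rows)
  return { mean, variance }
}
```

The appeal of this structure is that only the final single-row result (one value per channel) would ever need to leave the GPU, or it could stay there and feed the normalization shader directly.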

I hope someone can find a better implementation.