tensorflow / tfjs

A WebGL accelerated JavaScript library for training and deploying ML models.
https://js.tensorflow.org
Apache License 2.0

Bug in deserializing TF2.0 GRU layer's bias vector in tf.loadLayersModel() #2442

Closed philipp-ghtest closed 4 years ago

philipp-ghtest commented 4 years ago

Dear tfjs-team,

I ran into an issue these days when I tried to import a Keras .h5 model with GRU layers into tfjs; see https://github.com/tensorflow/tfjs/issues/2437

tldr: this error popped up:

 Uncaught (in promise) Error: Shape mismatch: [384] vs. [2,384]
     at variables.ts:135
     at t.write (variables.ts:98)
     at variables.ts:339
     at Array.forEach (<anonymous>)
     at sf (variables.ts:337)
     at e.loadWeights (container.ts:598)
     at models.ts:315
     at common.ts:14
     at Object.next (common.ts:14)
     at a (common.ts:14)

On further investigation, I figured out that there's something wrong in models.ts:300 in the deserialize() function.

Turns out that the model object returned by deserialize() sets faulty shapes for the bias vectors of GRU layers (those should be of shape [2, x] but are set to [x], here [384]):

(screenshot: tfjs_debug3)

On the contrary, the weights loaded in models.ts:313 by io.decodeWeights() are set correctly (here [2, 384]):

(screenshot: tfjs_debug2)

So there must be something wrong with deserialize() or one of its nested functions. I really tried to dig further, but I'm basically unfamiliar with JS/TS, so it's hard for me to get any further.

This bug should be easy to reproduce, just create some model with GRU layers in Keras, like this:

model = keras.models.Sequential([
#    keras.layers.GRU(128, return_sequences=True, batch_input_shape=[batch_size, None, max_id+1]),
    keras.layers.GRU(128, return_sequences=True, input_shape=[None, max_id+1]),
    keras.layers.GRU(128, return_sequences=True),
    keras.layers.GRU(128),
    keras.layers.Flatten(),
    keras.layers.Dense(output_size, activation="softmax")
])

I guess you don't even need to train it, just initializing it should be fine.

Then convert it with the tfjs-converter and load it with tf.loadLayersModel().

I'd be really grateful for any fixes or quick workarounds. Thank you in advance!

philipp-ghtest commented 4 years ago

Hey, any update on this?

caisq commented 4 years ago

@philipp-ghtest Sorry for the delay in getting back to you.

I looked at the Python code that you used to generate the model. Why are you including a keras.layers.Flatten() layer? The output of the 3rd (last) GRU layer already has a flat output shape (None, 128). If you remove that layer, things should work properly in TensorFlow.js.

caisq commented 4 years ago

I'm closing this issue for now. If you have any additional questions or concerns, please feel free to keep adding comments and/or re-open this issue. @philipp-ghtest

ckhordiasma commented 4 years ago

I am having this exact same issue. I am also using a GRU; in my case, my model looks like:

model = tf.keras.Sequential([
  tf.keras.layers.Embedding(vocab_size, embedding_dim,
                              batch_input_shape=[batch_size, None]),
  tf.keras.layers.GRU(rnn_units,
                        return_sequences=True,
                        stateful=True,
                        recurrent_initializer='glorot_uniform'),
  tf.keras.layers.Dense(vocab_size)
])

I am converting with Keras and the tensorflowjs command-line tool from pip: tensorflowjs_converter --input_format keras ./model.h5 ./output

When I try to load my model into TensorFlow.js with const model = tf.loadLayersModel('/output/model.json'), I get this error (in the browser; I'm working client-side):

Uncaught (in promise) Error: Shape mismatch: [3072] vs. [2,3072]

I have tried converting using the python tool instead of the command line tool

import tensorflowjs as tfjs
import tensorflow as tf
model = tf.keras.models.load_model('../model.h5')
tfjs.converters.save_keras_model(model, './output/')

but that has also not worked.

please let me know if there is any more info I can provide.

(Note: unlike the original poster, I don't have a Flatten layer after the GRU; I am pretty much just copying the Shakespeare tutorial from the TensorFlow website. So I don't think the Flatten layer is the issue.)

Update: I re-trained my model using an LSTM instead of a GRU, and I did not get the same error as before. Specifically, I changed my model to:

model = tf.keras.Sequential([
  tf.keras.layers.Embedding(vocab_size, embedding_dim,
                            batch_input_shape=[batch_size, None]),
  tf.keras.layers.LSTM(rnn_units, return_sequences=True),
  tf.keras.layers.Dense(vocab_size)
])

I think this helps confirm that there is something wrong with how tensorflow.js is handling GRU keras models.

philipp-ghtest commented 4 years ago

> @philipp-ghtest Sorry for the delay in getting back to you.
>
> I looked at the Python code that you used to generate the model. Why are you including a keras.layers.Flatten() layer? The output of the 3rd (last) GRU layer already has a flat output shape (None, 128). If you remove that layer, things should work properly in TensorFlow.js.

Nope, it doesn't. I've already tried this.

As I extensively described in my first post, this is due to a shape mismatch when setting the bias weights of the GRU layers in tfjs.

I'm still trying to track this issue down, but so far I'm pretty sure there's something wrong with how tfjs sets/initializes the weights of its GRU layers.

philipp-ghtest commented 4 years ago

@ckhordiasma

Thanks for your reply. This is indeed the same issue as mine. I'm glad I'm not the only one.

philipp-ghtest commented 4 years ago

@ckhordiasma @caisq

I found a workaround that works for me.

I noticed that the bias weights saved by Keras have shape [2, 384] (the 384 is case-specific, of course), and that in these weights the [0] row has exactly the same values as the [1] row. They're redundant. So is this more likely a bug in TF's Python code?

Anyway, if you manually set up the architecture and then manipulate the weights by cropping the redundant part of the bias from [2, 384] down to [384], it works!!

    // build the architecture manually first
    const model = tf.sequential();
    model.add(tf.layers.gru({units: 128, returnSequences: true, inputShape: [null, 63], name: 'gru'}));
    model.add(tf.layers.gru({units: 128, returnSequences: true, name: 'gru_1'}));
    model.add(tf.layers.gru({units: 128, name: 'gru_2'}));
    //model.add(tf.layers.flatten());
    model.add(tf.layers.dense({units: 3732, activation: 'sigmoid', name: 'dense'}));

    model.summary();

    // load the weights manually
    const path = 'http://localhost:3000/prototype_catcross_web_Kopie/model.json';
    const handlers = tf.io.getLoadHandlers(path);
    if (handlers.length === 0) {
        handlers.push(tf.io.browserHTTPRequest(path));
    }
    const handler = handlers[0];

    (async () => {
        const artifacts = await handler.load();
        const weights =
            tf.io.decodeWeights(artifacts.weightData, artifacts.weightSpecs);

        // crop the bias weights to the shape tfjs expects
        weights['gru/bias'] = weights['gru/bias'].as1D().slice([0], [384]);
        weights['gru_1/bias'] = weights['gru_1/bias'].as1D().slice([0], [384]);
        weights['gru_2/bias'] = weights['gru_2/bias'].as1D().slice([0], [384]);
        model.loadWeights(weights, true);

        const example = tf.tensor3d(// stupidly long tensor for testing ]]);
        const prediction = model.predict(example);

        // correct result is found!!!
        const argmax = prediction.squeeze().argMax();
        document.getElementById('demo').innerHTML = argmax;
        document.getElementById('summary').innerHTML = prediction;
    })();

So this workaround is quite ugly, but it clearly shows that it's because of the weights.

Sorry I can't post a more elaborate fix, since I'm really not good at TS/JS and absolutely not familiar with tfjs code.

EDIT: I just figured out that the seemingly identical rows [0] and [1] of the [2, 384] bias weight are not completely identical; they differ in their last third. I didn't notice this because my results were still good enough. But still, this is clearly an issue with the GRU layer's handling of the bias.
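For anyone who'd rather sanity-check this on the Python side, here is a minimal numpy sketch of what the cropping above does. It assumes, per the observations in this thread, that the saved bias stacks an input-bias row and a recurrent-bias row and that the two rows differ only in their last (candidate-gate) third; the toy values are made up for illustration:

```python
import numpy as np

units = 128                          # GRU units from the model above
stacked = np.zeros((2, 3 * units))   # hypothetical TF2-style bias: row 0 = input bias, row 1 = recurrent bias
stacked[0, :] = 0.1
stacked[1, :] = 0.1
stacked[1, 2 * units:] = 0.2         # rows differ only in the candidate-gate third, per the EDIT above

cropped = stacked[0]                 # same effect as the JS slice([0], [384]) workaround
print(cropped.shape)                 # (384,)
print(np.array_equal(stacked[0, :2 * units], stacked[1, :2 * units]))  # True: first two thirds agree
```

Keeping only row 0 therefore discards the recurrent-bias values of the candidate gate, which is presumably why the results are close but not exactly identical.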

philipp-ghtest commented 4 years ago

@caisq can you maybe reopen this and label it as a bug and not as support? That would really be great, thanks!

edit: thank you

caisq commented 4 years ago

@philipp-ghtest and @ckhordiasma I realized I should have asked you for this basic piece of info before: what versions of tensorflow and tensorflowjs (in python) and @tensorflow/(tfjs|tfjs-node) (in JavaScript) are you using?

I just did a quick test of @ckhordiasma 's GRU model with tensorflow 1.15.0, tensorflowjs 1.1.5, and @tensorflow/tfjs-node 1.3.2, and the model saving, conversion, and loading seem to work properly.

ckhordiasma commented 4 years ago

I'm traveling and don't have access to my computer for more specifics, but I know I'm using tensorflow 2.0 and the most recent version of tensorflowjs on PyPI for the conversion.


philipp-ghtest commented 4 years ago

@caisq I trained and exported the model with TensorFlow 2.0.0, and converted it with the tensorflowjs_converter from tensorflowjs 1.3.2, both installed via pip.

In JS, I tried both tfjs 1.0.0 and 1.3.2, imported via script tag, but neither works:

<script src="https://cdn.jsdelivr.net/npm/@tensorflow/tfjs@1.3.2/dist/tf.min.js"></script>

caisq commented 4 years ago

@philipp-ghtest I can verify that this error exists for models saved by tensorflow 2.0.0. I'll continue investigating this issue.

The workaround for you folks right now is to use tensorflow 1.15 to save the model (followed by conversion and model loading in JS). Sorry for the inconvenience.

caisq commented 4 years ago

Changing priority to P2 as there is a workaround available.

This seems to originate from a change in the behavior of the GRU layer (specifically the GRUCell class) between TF 1.15 and 2.0. In the older version, the bias has a shape of (9,); in the newer one, it has a shape of (2, 9), which is a stacking of two biases (input_bias and recurrent_bias). TensorFlow.js hasn't caught up with this change yet.
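To make the shape change concrete, here is a small numpy sketch. The gate layout and the caveat about the candidate gate reflect my reading of the Keras GRUCell behavior with reset_after, not something verified in this thread:

```python
import numpy as np

units = 3  # matches the (9,) vs. (2, 9) shapes mentioned above

# TF 1.x-style GRU (reset_after=False): one fused bias for the z, r, h gates
bias_v1 = np.zeros(3 * units)        # shape (9,)

# TF 2.0-style GRU (reset_after=True): input bias and recurrent bias stacked
bias_v2 = np.zeros((2, 3 * units))   # shape (2, 9)

# For the update (z) and reset (r) gates, the two rows are simply added inside
# the activation, so summing them recovers a fused bias for those gates. For
# the candidate (h) gate this is NOT exact, because with reset_after=True the
# recurrent bias sits inside the reset-gate product.
fused_approx = bias_v2.sum(axis=0)   # shape (9,)
print(bias_v1.shape, bias_v2.shape, fused_approx.shape)  # (9,) (2, 9) (9,)
```

This is also why simply cropping or summing the stacked bias can only approximate a reset_after=False model rather than reproduce it exactly.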

cc @fchollet @karmel

philipp-ghtest commented 4 years ago

Thank you very much for looking into it and for the workaround.

Unfortunately, the workaround only works if you retrain the model on TF 1.1x. I tried loading my saved .h5 model into my Python training script and exporting it again for conversion with tfjs.

But on model.load_weights(os.path.join(data_dir, "prototype.h5")) I get the following error:

ValueError: GRU(reset_after=True) is not compatible with GRU(reset_after=False)

So, I tried again, this time setting reset_after=True:

model = keras.models.Sequential([
#    keras.layers.GRU(128, return_sequences=True, batch_input_shape=[batch_size, None, max_id+1]),
    keras.layers.GRU(128, return_sequences=True, input_shape=[None, max_id+1], reset_after=True),
    keras.layers.GRU(128, return_sequences=True, reset_after=True),
    keras.layers.GRU(128, reset_after=True),
    keras.layers.Flatten(),
    keras.layers.Dense(output_size, activation="softmax")
])

This time, loading my saved weights works fine. But when I export this configuration, I get the same error in TFJS:

Uncaught (in promise) Error: Shape mismatch: [384] vs. [2,384]

So I'm retraining my model right now; I'll report back whether the workaround works once it's done.

EDIT: loading the retrained model in TFJS works. Thanks again to everyone here!

elazarg commented 4 years ago

reset_after=True doesn't work for me:

import tensorflow as tf

USE_BIAS = True  # Fails iff True
model = tf.keras.Sequential([
    tf.keras.layers.Embedding(5, 7),
    tf.keras.layers.GRU(3, use_bias=USE_BIAS, reset_after=True)
])

import tensorflowjs as tfjs
tfjs.converters.save_keras_model(model, 'bad_js/')

Version 1.6.0

psaikko commented 4 years ago

This seems to still be current. I am seeing this same shape mismatch issue with

TensorFlow.js version @tensorflow/tfjs@1.7.4/dist/tf.min.js

Python packages: tensorflow 2.1.0, tensorflowjs 1.7.4.post1

I am saving my model with tfjs.converters.save_keras_model

Uncaught (in promise) Error: Shape mismatch: [192] vs. [2,192]
    at variables.ts:135
    at t.write (variables.ts:98)
    at variables.ts:339
    at Array.forEach (<anonymous>)
    at od (variables.ts:337)
    at e.loadWeights (container.ts:628)
    at models.ts:318
    at common.ts:14
    at Object.next (common.ts:14)
    at o (common.ts:14)

If I replace the GRU layer with a SimpleRNN layer in the model, it loads correctly in tfjs.

edit: possibly being fixed by https://github.com/tensorflow/tfjs/pull/3377

GUR9000 commented 4 years ago

Can someone explain why the GRU would need twice as many biases in TF 2.0 compared to all other implementations of GRUs (original paper, TF1, other frameworks, etc.)? Is this a TF2 bug?

The original bias tensor already contained the non-recurrent and recurrent parts, so the (2, XX) shape seems superfluous.