tensorflow / tfjs

A WebGL accelerated JavaScript library for training and deploying ML models.
https://js.tensorflow.org
Apache License 2.0

layersModel working for inference but breaks when training #5381

Open stanleyjzheng opened 3 years ago

stanleyjzheng commented 3 years ago

System information

Describe the current behavior

When a Keras-converted model is loaded with tf.loadLayersModel, training fails with a shape mismatch that does not occur during inference. The failure commonly occurs around globalAveragePooling2d and globalMaxPooling2d layers.

Standalone code to reproduce the issue

A minimal example is to convert MobileNet from the tf.keras.applications implementation to TensorFlow.js and train it on random input (I understand MobileNet is also part of tfjs-models; it is used here only as a demonstration). We first save the model, then convert it with tensorflowjs_converter.

To save MobileNet with Python:

import tensorflow as tf

mobilenet = tf.keras.applications.mobilenet.MobileNet(include_top=True, weights='imagenet')
mobilenet.save('mobilenet.h5')

Then convert it to TensorFlow.js:

tensorflowjs_converter --input_format=keras mobilenet.h5 mobilenet

Finally, we can load MobileNet in JavaScript and observe the error:

const tf = require("@tensorflow/tfjs-node");

async function test() {
    const model = await tf.loadLayersModel("file://./mobilenet/model.json");
    model.compile({loss: 'categoricalCrossentropy', optimizer: 'sgd'});
    console.log(model.predict(tf.randomNormal([1, 3, 224, 224])));
    await model.fit(tf.randomNormal([64, 3, 224, 224]), tf.randomNormal([64, 1000]), {batchSize: 4});
}

test();

Output:

Tensor {
  kept: false,
  isDisposedInternal: false,
  shape: [ 1, 1000 ],
  dtype: 'float32',
  size: 1000,
  strides: [ 1000 ],
  dataId: {},
  id: 2154,
  rankType: '2',
  scopeId: 1
}
Epoch 1 / 1
TF_Status: 3
Message: Input to reshape is a tensor with 200704 values, but the requested shape has 1280

Inference works, but training does not. The 200704 values in the error correspond to the input of the globalAveragePooling2d layer: batch_size * 1024 * 7 * 7 = 4 * 50176 = 200704. In Python, the same model has identical shapes and trains without error.
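As a quick sanity check on that arithmetic, here is a minimal sketch in plain JavaScript (no TensorFlow.js required). `numElements` is a hypothetical helper introduced only for illustration, and the 1024 * 7 * 7 feature-map size is taken from the description above rather than read from the model:

```javascript
// Hypothetical helper: number of values in a tensor with the given shape.
function numElements(shape) {
  return shape.reduce((acc, dim) => acc * dim, 1);
}

// batchSize comes from the model.fit call in the reproduction script.
const batchSize = 4;

// Assumed shape of the input to the global pooling layer, per the report.
const poolingInputShape = [batchSize, 1024, 7, 7];

// Values per sample, and values for one training batch.
const perSample = numElements(poolingInputShape.slice(1));
const perBatch = numElements(poolingInputShape);

console.log(perSample); // 50176
console.log(perBatch);  // 200704
```

The per-batch count matches the number in the reshape error, which is consistent with the mismatch arising on a tensor the size of the pooling layer's input.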

This is observed around globalAveragePooling2d and globalMaxPooling2d in many models, including the official tf.keras.applications implementations of MobileNetV2 and EfficientNet. No example is provided for those because their implementations are much longer.

Other info / logs

Traceback:

Message: Input to reshape is a tensor with 2007040 values, but the requested shape has 1280
    at NodeJSKernelBackend.executeSingleOutput (/home/stanleyzheng/kds/kds-melanoma/test/node_modules/@tensorflow/tfjs-node/dist/nodejs_kernel_backend.js:211:43)
    at Object.kernelFunc (/home/stanleyzheng/kds/kds-melanoma/test/node_modules/@tensorflow/tfjs-node/dist/kernels/Reshape.js:33:27)
    at kernelFunc (/home/stanleyzheng/kds/kds-melanoma/test/node_modules/@tensorflow/tfjs-core/dist/tf-core.node.js:4672:32)
    at /home/stanleyzheng/kds/kds-melanoma/test/node_modules/@tensorflow/tfjs-core/dist/tf-core.node.js:4733:27
    at Engine.scopedRun (/home/stanleyzheng/kds/kds-melanoma/test/node_modules/@tensorflow/tfjs-core/dist/tf-core.node.js:4537:23)
    at Engine.runKernelFunc (/home/stanleyzheng/kds/kds-melanoma/test/node_modules/@tensorflow/tfjs-core/dist/tf-core.node.js:4729:14)
    at Engine.runKernel (/home/stanleyzheng/kds/kds-melanoma/test/node_modules/@tensorflow/tfjs-core/dist/tf-core.node.js:4601:21)
    at reshape_ (/home/stanleyzheng/kds/kds-melanoma/test/node_modules/@tensorflow/tfjs/dist/tf.node.js:4471:19)
    at reshape__op (/home/stanleyzheng/kds/kds-melanoma/test/node_modules/@tensorflow/tfjs/dist/tf.node.js:3744:28)
    at Object.derScale [as scale] (/home/stanleyzheng/kds/kds-melanoma/test/node_modules/@tensorflow/tfjs/dist/tf.node.js:7496:20)

gaikwadrahul8 commented 1 year ago

Hi, @stanleyjzheng

I apologize for the delayed response. I tried to replicate this issue with the latest version of @tensorflow/tfjs-node and tensorflowjs@4.6.0 and got the error message below, so we'll have to dig further into this and will update you soon. Thank you for reporting this issue; I really appreciate your time and effort.

CC: @mattsoulanille

Here is the error log:

gaikwadrahul-macbookpro:TFJS gaikwadrahul$ node test.js
Platform node has already been set. Overwriting the platform with node.
/Users/gaikwadrahul/Desktop/TFJS/node_modules/@tensorflow/tfjs/node_modules/@tensorflow/tfjs-layers/dist/tf-layers.node.js:273
        var _this = _super.call(this, message) || this;
                           ^

ValueError: Input 0 is incompatible with layer conv_preds: expected ndim=4, found ndim=2
    at new ValueError (/Users/gaikwadrahul/Desktop/TFJS/node_modules/@tensorflow/tfjs/node_modules/@tensorflow/tfjs-layers/dist/tf-layers.node.js:273:28)
    at Layer.assertInputCompatibility (/Users/gaikwadrahul/Desktop/TFJS/node_modules/@tensorflow/tfjs/node_modules/@tensorflow/tfjs-layers/dist/tf-layers.node.js:3020:27)
    at /Users/gaikwadrahul/Desktop/TFJS/node_modules/@tensorflow/tfjs/node_modules/@tensorflow/tfjs-layers/dist/tf-layers.node.js:3263:19
    at nameScope (/Users/gaikwadrahul/Desktop/TFJS/node_modules/@tensorflow/tfjs/node_modules/@tensorflow/tfjs-layers/dist/tf-layers.node.js:978:19)
    at Layer.apply (/Users/gaikwadrahul/Desktop/TFJS/node_modules/@tensorflow/tfjs/node_modules/@tensorflow/tfjs-layers/dist/tf-layers.node.js:3222:16)
    at processNode (/Users/gaikwadrahul/Desktop/TFJS/node_modules/@tensorflow/tfjs/node_modules/@tensorflow/tfjs-layers/dist/tf-layers.node.js:22341:23)
    at Container.fromConfig (/Users/gaikwadrahul/Desktop/TFJS/node_modules/@tensorflow/tfjs/node_modules/@tensorflow/tfjs-layers/dist/tf-layers.node.js:22402:33)
    at deserializeKerasObject (/Users/gaikwadrahul/Desktop/TFJS/node_modules/@tensorflow/tfjs/node_modules/@tensorflow/tfjs-layers/dist/tf-layers.node.js:674:29)
    at deserialize (/Users/gaikwadrahul/Desktop/TFJS/node_modules/@tensorflow/tfjs/node_modules/@tensorflow/tfjs-layers/dist/tf-layers.node.js:20073:12)
    at /Users/gaikwadrahul/Desktop/TFJS/node_modules/@tensorflow/tfjs/node_modules/@tensorflow/tfjs-layers/dist/tf-layers.node.js:25604:29

Node.js v18.15.0