tensorflow / java

Java bindings for TensorFlow
Apache License 2.0

Tensor type issue #505

Closed SaherAlSous closed 8 months ago

SaherAlSous commented 8 months ago


Describe the current behavior

I'm trying to use a TensorFlow model in Spring Boot. I loaded the model successfully and created the needed tensor, but I can't call the model to get a result because of this error:

    Caused by: org.tensorflow.exceptions.TFInvalidArgumentException: Expects arg[0] to be float but uint8 is provided

I checked the model signature and it looks like this:

    Signature for "serving_default":
      Method: "tensorflow/serving/predict"
      Inputs:
        "input_1": dtype=DT_FLOAT, shape=(-1, 299, 299, 3)
      Outputs:
        "dense_3": dtype=DT_FLOAT, shape=(-1, 41)

    Signature for "__saved_model_init_op":
      Outputs:
        "__saved_model_init_op": dtype=DT_INVALID, shape=()

My tensor is a DT_UINT8 tensor with shape [299, 299, 3]. When I tried to switch to float like this:

        val imageShape = TFloat32.tensorOf(runner.fetch(decodeImage).run()[0].shape())
        val reshape = tf.reshape(
            decodeImage,
            tf.array(
                -1.0f,
                imageShape[0].getFloat(),
                imageShape[1].getFloat(),
                imageShape[2].getFloat())
        )

I got this error:

    org.tensorflow.exceptions.TFInvalidArgumentException: Value for attr 'Tshape' of float is not in the list of allowed values: int32, int64

Describe the expected behavior

Inference (prediction) completes successfully.

Code to reproduce the issue

Loading the model in TFServices:

        fun model(): SavedModelBundle {
            return SavedModelBundle
                .loader("/home/***/src/main/resources/pd/")
                .withRunOptions(RunOptions.getDefaultInstance())
                .load()
        }

Building the tensor and calling the model:

        val graph = Graph()
        val session = Session(graph)
        val tf = Ops.create(graph)
        val fileName = tf.constant("/home/***/src/main/resources/keyframe_1294.jpg")
        val readFile = tf.io.readFile(fileName)
        val runner = session.runner()
        val decodingOptions = DecodeJpeg.channels(3)
        // decodeJpeg produces a uint8 tensor of shape [height, width, 3]
        val decodeImage = tf.image.decodeJpeg(readFile.contents(), decodingOptions)
        val imageShape = runner.fetch(decodeImage).run()[0].shape()
        // Reshape to [-1, height, width, 3]; this changes only the shape, the dtype stays uint8
        val reshape = tf.reshape(
            decodeImage,
            tf.array(
                -1,
                imageShape.asArray()[0],
                imageShape.asArray()[1],
                imageShape.asArray()[2])
            )
        val tensor = runner.fetch(reshape).run()[0]
        // The "serving_default" signature above names its input "input_1"
        val inputMap = mutableMapOf("input_tensor" to tensor)
        println(tensor.shape())
        println(tensor.dataType())
        println(tensor.asRawTensor())
        val result = tfService.model().function("serving_default").call(inputMap)


Craigacp commented 8 months ago

The reshape command you're using isn't a cast; you'll need to cast the tensor from uint8 to float32, assuming that's what the model wants.
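A minimal Kotlin sketch of that fix, assuming TensorFlow Java 0.4.x or 0.5.x on the classpath: cast the decoded `uint8` image to `float32` with `tf.dtypes.cast`, then add the batch dimension with `tf.expandDims`, so the result matches the `DT_FLOAT`, `(-1, 299, 299, 3)` input of the `serving_default` signature. The helpers `toModelInput` and `inspectModelInput` are hypothetical, and a zero-filled tensor stands in for the decoded JPEG. Pixel values are not rescaled; whether the model expects `[0, 255]` or a normalized range is an assumption to check against how it was trained.

```kotlin
import org.tensorflow.Graph
import org.tensorflow.Operand
import org.tensorflow.Session
import org.tensorflow.op.Ops
import org.tensorflow.types.TFloat32
import org.tensorflow.types.TUint8

// Hypothetical helper: turns a decoded image of shape [height, width, 3]
// (uint8, as returned by tf.image.decodeJpeg) into a float32 batch of one
// image with shape [1, height, width, 3]. Pixel values are not rescaled.
fun toModelInput(tf: Ops, image: Operand<TUint8>): Operand<TFloat32> {
    // Reshape never changes the dtype, so an explicit cast is needed.
    val asFloat = tf.dtypes.cast(image, TFloat32::class.java)
    // Insert the batch dimension at axis 0.
    return tf.expandDims(asFloat, tf.constant(0))
}

// Hypothetical check: builds the model input from a zero-filled uint8 image
// and returns the resulting shape and dtype name.
fun inspectModelInput(height: Long, width: Long): Pair<List<Long>, String> =
    Graph().use { graph ->
        Session(graph).use { session ->
            val tf = Ops.create(graph)
            // Zero-filled stand-in for the output of tf.image.decodeJpeg.
            val image = tf.zeros(tf.constant(longArrayOf(height, width, 3)), TUint8::class.java)
            val input = toModelInput(tf, image)
            // A fresh runner per run(), so earlier fetches are not returned again.
            session.runner().fetch(input).run()[0].use { tensor ->
                tensor.shape().asArray().toList() to tensor.dataType().name
            }
        }
    }
```

`inspectModelInput(299, 299)` should report the shape `[1, 299, 299, 3]` and `DT_FLOAT`. The sketch creates a new `session.runner()` for each `run()` because a runner keeps every fetch added to it; in the issue's code the second `run()` on the same runner likely returns the `decodeImage` output at index 0, which would be consistent with the reported `[299, 299, 3]` uint8 tensor.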

karllessard commented 8 months ago

+1 to @Craigacp's comment: use tf.dtypes.cast to convert your UInt8 decoded image to Float.

Also, it looks like you are trying to run TF eagerly but are using a graph Session? You might want to look at EagerSession instead, especially since the way you are doing it now will leak: you don't close the output tensors (nor the inputs).
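A minimal sketch of the eager alternative, again assuming TensorFlow Java 0.4.x or 0.5.x: with an `EagerSession`, each op executes as soon as it is created, `asTensor()` reads the result directly, and closing the session (here through Kotlin's `use`) releases the tensors it produced. `eagerModelInputShape` is a hypothetical helper that returns only the input shape, because those tensors are no longer valid once the session is closed.

```kotlin
import org.tensorflow.EagerSession
import org.tensorflow.op.Ops
import org.tensorflow.op.image.DecodeJpeg
import org.tensorflow.types.TFloat32

// Hypothetical helper: reads a JPEG eagerly, builds the float32 model input
// from it and returns that input's shape.
fun eagerModelInputShape(jpegPath: String): List<Long> =
    // Closing the EagerSession releases every tensor it produced.
    EagerSession.create().use { eager ->
        val tf = Ops.create(eager)
        // In eager mode each op runs immediately: no Graph, Session or Runner.
        val contents = tf.io.readFile(tf.constant(jpegPath)).contents()
        val image = tf.image.decodeJpeg(contents, DecodeJpeg.channels(3)) // uint8, [h, w, 3]
        val input = tf.expandDims(tf.dtypes.cast(image, TFloat32::class.java), tf.constant(0))
        val tensor: TFloat32 = input.asTensor()
        // Assumption: tensor could be passed to the model here, e.g. mapOf("input_1" to tensor);
        // per the comment above, the model's output tensors would still need to be closed.
        tensor.shape().asArray().toList()
    }
```

For the image in the issue, `eagerModelInputShape("/home/***/src/main/resources/keyframe_1294.jpg")` should give `[1, 299, 299, 3]`.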

SaherAlSous commented 8 months ago

Thank you.