Kotlin / kotlindl

High-level Deep Learning Framework written in Kotlin and inspired by Keras
Apache License 2.0

Type cast error when running inference on an ONNX model #285

gasabr opened this issue 2 years ago

gasabr commented 2 years ago

I'm trying to run inference on an ONNX model (created from a LightGBM model) via Kotlin DL, and every method fails (I tried the Raw ones too): either with `class [J cannot be cast to class [[F ([J and [[F are in module java.base of loader 'bootstrap')`, or, in the Raw methods, with `SequenceInfo cannot be cast to class ai.onnxruntime.TensorInfo`.

Env:

    implementation("org.jetbrains.kotlinx:kotlin-deeplearning-api:0.3.0")
    implementation("org.jetbrains.kotlinx:kotlin-deeplearning-onnx:0.3.0-alpha-3")

Code (the model is attached as twoTierModel.txt):

package co.`fun`

import kotlin.test.*
import org.jetbrains.kotlinx.dl.api.inference.onnx.OnnxInferenceModel
import kotlin.random.Random

class ApplicationTest {
    @Test
    fun testRoot() {
        val onnxModel = OnnxInferenceModel.load("/tmp/twoTierModel.onnx")
        onnxModel.reshape(27)
        val features = (1..27).map { Random.nextFloat() }.toFloatArray()
        val prediction = onnxModel.predictSoftly(features, "features")
    }
}

The error is at line 124 of OnnxInferenceModel.kt and is caused by an attempt to cast a List to an Array. I'm not sure whether the model should always return a 3D tensor or whether the library should check the types (a quick way to inspect the declared output types is sketched at the end of this comment).

Rename the attachment to twoTierModel.onnx to try the test on your machine.
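Here is a minimal sketch (using the plain onnxruntime Java API, nothing Kotlin DL-specific) that prints the declared type of every model output; it should make the tensor-vs-sequence mismatch behind the cast errors visible:

```kotlin
import ai.onnxruntime.OrtEnvironment
import ai.onnxruntime.SequenceInfo
import ai.onnxruntime.TensorInfo

fun main() {
    val env = OrtEnvironment.getEnvironment()
    val session = env.createSession("/tmp/twoTierModel.onnx")
    // Print the declared ONNX type of every output. A tensor output can be
    // cast to (nested) primitive arrays; a sequence output cannot, which is
    // what the SequenceInfo -> TensorInfo cast error suggests is happening.
    for ((name, nodeInfo) in session.outputInfo) {
        when (val info = nodeInfo.info) {
            is TensorInfo -> println("$name: tensor ${info.type}, shape=${info.shape.contentToString()}")
            is SequenceInfo -> println("$name: sequence (sequenceOfMaps=${info.sequenceOfMaps})")
            else -> println("$name: ${info.javaClass.simpleName}")
        }
    }
}
```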

zaleslaw commented 2 years ago

Please add the attachment, @gasabr. Could you try to repeat the experiment with the latest version of the onnx dependency:

    implementation("org.jetbrains.kotlinx:kotlin-deeplearning-api:0.3.0")
    implementation("org.jetbrains.kotlinx:kotlin-deeplearning-onnx:0.3.0")

gasabr commented 2 years ago

twoTierModel.txt

Thanks for the response! I tried 0.3.0 and got the same exception. I also tried the plain Java onnxruntime and was able to run inference on the model with the following code:

    import ai.onnxruntime.OnnxTensor
    import ai.onnxruntime.OrtEnvironment
    import ai.onnxruntime.OrtSession
    import java.nio.FloatBuffer
    import kotlin.random.Random

    val env = OrtEnvironment.getEnvironment()
    val session = env.createSession("/tmp/twoTierModel.onnx", OrtSession.SessionOptions())
    val features = (1..27).map { Random.nextFloat() }.toFloatArray()
    val buf = FloatBuffer.wrap(features)
    val t1 = OnnxTensor.createTensor(env, buf, longArrayOf(1, 27))
    val inputs = mapOf<String, OnnxTensor>("features" to t1)
    // The "probabilities" output decodes as a list of per-sample maps.
    val result = session.run(inputs, setOf("probabilities"))[0].value as ArrayList<HashMap<Long, Float>>
    println(result)
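(For context: each element of the result is a map from class label to probability for one input row. Converters for LightGBM classifiers typically emit this sequence-of-maps output via a ZipMap-style post-processing node, which is presumably the SequenceInfo that trips Kotlin DL's cast.)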

zaleslaw commented 2 years ago

Thanks for the example you gave me, @gasabr. I hope to fix this in the 0.4 release to cover more cases, but at the moment the Java onnxruntime is the best choice for you, I agree.

ermolenkodev commented 2 years ago

I want to discuss a couple of things.

ONNX supports multiple output types, such as tensors, sequences of numbers (or strings), sequences of maps, and maps. At first glance, it seems possible to decode every output type into an appropriate Kotlin data structure using the OnnxModel metadata. But I have some doubts:

A dirty draft of the decoding function:

```kotlin
private fun decodeOnnxOutput(onnxOutput: OrtSession.Result): Map<String, Any> {
    val keys = onnxOutput.map { it.key }
    return keys.associateWith { key ->
        if (key !in this.session.outputInfo) throw RuntimeException()
        when (val info = this.session.outputInfo[key]!!.info) {
            is TensorInfo -> {
                val tensor = onnxOutput.get(key).get().value
                if (info.shape.size == 1) {
                    when (info.type) {
                        OnnxJavaType.FLOAT -> tensor as FloatArray
                        OnnxJavaType.DOUBLE -> tensor as DoubleArray
                        OnnxJavaType.INT8 -> tensor as ByteArray
                        OnnxJavaType.INT16 -> tensor as ShortArray
                        OnnxJavaType.INT32 -> tensor as IntArray
                        OnnxJavaType.INT64 -> tensor as LongArray
                        OnnxJavaType.UINT8 -> tensor as UByteArray
                        else -> throw RuntimeException()
                    }
                } else {
                    when (info.type) {
                        OnnxJavaType.FLOAT -> tensor as Array<FloatArray>
                        OnnxJavaType.DOUBLE -> tensor as Array<DoubleArray>
                        OnnxJavaType.INT8 -> tensor as Array<ByteArray>
                        OnnxJavaType.INT16 -> tensor as Array<ShortArray>
                        OnnxJavaType.INT32 -> tensor as Array<IntArray>
                        OnnxJavaType.INT64 -> tensor as Array<LongArray>
                        OnnxJavaType.UINT8 -> tensor as Array<UByteArray>
                        else -> throw RuntimeException()
                    }
                }
            }
            is SequenceInfo -> {
                val elements = onnxOutput.get(key).get().value as List<*>
                if (info.sequenceOfMaps) {
                    elements.map {
                        when (info.mapInfo.keyType to info.mapInfo.valueType) {
                            OnnxJavaType.INT64 to OnnxJavaType.FLOAT -> it as HashMap<Long, Float>
                            OnnxJavaType.STRING to OnnxJavaType.FLOAT -> it as HashMap<String, Float>
                            else -> throw RuntimeException()
                        }
                    }
                } else {
                    when (info.sequenceType) {
                        OnnxJavaType.FLOAT -> elements as List<Float>
                        OnnxJavaType.DOUBLE -> elements as List<Double>
                        OnnxJavaType.INT64 -> elements as List<Long>
                        OnnxJavaType.STRING -> elements as List<String>
                        else -> throw RuntimeException()
                    }
                }
            }
            is MapInfo -> {
                val map = onnxOutput.get(key).get().value
                when (info.keyType) {
                    OnnxJavaType.INT64 -> when (info.valueType) {
                        OnnxJavaType.FLOAT -> map as HashMap<Long, Float>
                        OnnxJavaType.DOUBLE -> map as HashMap<Long, Double>
                        OnnxJavaType.INT64 -> map as HashMap<Long, Long>
                        OnnxJavaType.STRING -> map as HashMap<Long, String>
                        else -> throw RuntimeException()
                    }
                    OnnxJavaType.STRING -> when (info.valueType) {
                        OnnxJavaType.FLOAT -> map as HashMap<String, Float>
                        OnnxJavaType.DOUBLE -> map as HashMap<String, Double>
                        OnnxJavaType.INT64 -> map as HashMap<String, Long>
                        OnnxJavaType.STRING -> map as HashMap<String, String>
                        else -> throw RuntimeException()
                    }
                    else -> throw RuntimeException()
                }
            }
            else -> throw RuntimeException()
        }
    }
}
```
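A hypothetical call site for the draft above (reusing the session and inputs from the earlier snippet; the output name "probabilities" comes from the attached model):

```kotlin
// Hypothetical usage of decodeOnnxOutput: decode all outputs by their
// declared types, then pick out the sequence-of-maps "probabilities" output.
val decoded = decodeOnnxOutput(session.run(inputs))
@Suppress("UNCHECKED_CAST")
val probabilities = decoded["probabilities"] as List<Map<Long, Float>>
println(probabilities.first())
```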

ermolenkodev commented 2 years ago

Another thing I want to discuss: to me it seems reasonable to refactor OnnxInferenceModel's predict and predictSoftly methods out into a more specific implementation class (e.g. ClassificationOnnxInferenceModel).

It may be handy if OnnxInferenceModel worked with arbitrary tensors, while classes targeted at specific DL tasks (such as detection or segmentation) used OnnxInferenceModel internally and formatted its output for the specific task.
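For illustration, a rough sketch of what that split might look like (all class and method names here are hypothetical, not the existing Kotlin DL API):

```kotlin
import ai.onnxruntime.OnnxTensor
import ai.onnxruntime.OrtSession

// Hypothetical names only; nothing below is the current Kotlin DL API.
open class GenericOnnxModel(protected val session: OrtSession) {
    // Generic layer: returns the raw decoded payload of every requested
    // output and lets callers (or subclasses) interpret it.
    fun predictRaw(inputs: Map<String, OnnxTensor>, outputs: Set<String>): Map<String, Any> =
        session.run(inputs, outputs).use { result ->
            outputs.associateWith { name -> result.get(name).get().value }
        }
}

class ClassificationOnnxModel(session: OrtSession) : GenericOnnxModel(session) {
    // Task-specific layer: a converted LightGBM classifier exposes
    // "probabilities" as a sequence of maps (class label -> probability).
    @Suppress("UNCHECKED_CAST")
    fun predictProbabilities(inputs: Map<String, OnnxTensor>): List<Map<Long, Float>> =
        predictRaw(inputs, setOf("probabilities"))["probabilities"] as List<Map<Long, Float>>
}
```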