deepjavalibrary / djl

An Engine-Agnostic Deep Learning Framework in Java
https://djl.ai
Apache License 2.0

Getting a cannot be cast to ai.djl.mxnet.engine.MxNDArray error #2645

AbhishekBose closed this issue 1 year ago

AbhishekBose commented 1 year ago

I am trying to implement a pooling operation on the output of an ONNX model.

The code for mean pooling is given below:

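import ai.djl.ndarray.NDArray
import ai.djl.translate.TranslatorContext

// Mean pooling over the token axis (axis 1 of a (batch, tokens, hidden) array), then L2 normalization
def pool(embeddings: NDArray, ctx: TranslatorContext): NDArray = {
    // All-ones mask with the same shape as the embeddings, so every token counts
    val attention_mask = ctx.getNDManager.ones(embeddings.getShape)
    val numerator = embeddings.mul(attention_mask)
    // Array(1) means "axis 1"; new Array[Int](1) would be Array(0), i.e. the batch axis
    val numerator_summed = numerator.sum(Array(1))
    val denominator = attention_mask.sum(Array(1))
    val finalOutput: NDArray = numerator_summed.div(denominator)
    finalOutput.normalize(2, 1)
  }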
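For reference, here is a dependency-free sketch of the arithmetic the pool function is meant to perform: a masked mean over the token axis followed by L2 normalization. It works on plain Scala arrays for a single batch item, assuming embeddings of shape (seqLen, hidden) and one 0/1 mask value per token; the object and method names are hypothetical, and this is not DJL code. Note that in Scala, new Array[Int](1) allocates a one-element array containing 0, so passing it to sum reduces axis 0 rather than axis 1; Array(1) is the literal for axis 1.

```scala
// Dependency-free sketch (hypothetical names, not DJL code) of masked mean
// pooling followed by L2 normalization, for a single batch item.
// Assumes embeddings has shape (seqLen, hidden) and mask has one 0/1 value per token.
object MeanPoolingSketch {

  /** Masked mean over the token (sequence) axis. */
  def meanPool(embeddings: Array[Array[Float]], mask: Array[Float]): Array[Float] = {
    require(embeddings.length == mask.length, "one mask entry per token")
    val hidden = embeddings.head.length
    val summed = new Array[Float](hidden) // zero-initialised accumulator
    for (t <- embeddings.indices; h <- 0 until hidden)
      summed(h) += embeddings(t)(h) * mask(t)
    // Clamp so that a fully masked input does not divide by zero.
    val count = math.max(mask.sum, 1e-9f)
    summed.map(_ / count)
  }

  /** L2-normalise a vector (p = 2), similar in spirit to normalize(2, 1) on a (1, hidden) array. */
  def l2Normalize(v: Array[Float]): Array[Float] = {
    val norm = math.sqrt(v.map(x => x.toDouble * x).sum).toFloat
    val denom = math.max(norm, 1e-12f) // avoid dividing by zero for an all-zero vector
    v.map(_ / denom)
  }

  def main(args: Array[String]): Unit = {
    val embeddings = Array(
      Array(1f, 2f),
      Array(3f, 4f),
      Array(100f, 100f) // padding token, masked out below
    )
    val mask = Array(1f, 1f, 0f)
    val pooled = meanPool(embeddings, mask)
    println(pooled.mkString(", ")) // 2.0, 3.0
    println(l2Normalize(pooled).mkString(", "))
  }
}
```

The padding row is excluded by the mask, so the mean is taken over the first two tokens only. In the original pool function the mask is all ones, which makes it a plain mean over axis 1.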

I am calling this function from the processOutput method of the translator. When it runs, I can see that both embeddings and attention_mask are OrtNDArrays of shape (1, 7, 768), but at the element-wise multiplication (val numerator = embeddings.mul(attention_mask)) I get the following error:

java.lang.ClassCastException: ai.djl.onnxruntime.engine.OrtNDArray cannot be cast to ai.djl.mxnet.engine.MxNDArray

[Screenshot of the array contents attached (2023-06-09)]

I can also see that the manager is the same for both arrays, since I pass the same translator context into the mean pooling method.

cc: @frankfliu

frankfliu commented 1 year ago

@AbhishekBose Hybrid engines (OnnxRuntime in this case) don't support NDArray operations themselves, so we have to leverage PyTorch or MXNet for them. We try to make this automatic, but it seems there are certain cases where it cannot. OrtNDArray.mul(MxNDArray array) should work, but MxNDArray.mul(OrtNDArray array) doesn't.
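To make that asymmetry concrete, here is a toy model in plain Scala. The classes NativeArr, AdapterArr and the adapt helper are hypothetical and are not DJL's implementation; the sketch only assumes one plausible mechanism consistent with the description: a "hybrid" array converts itself (but not its argument) to the alternative engine's array type before delegating, and the alternative engine casts its argument to its own type. Under that assumption, adapter.mul(native) works, while adapter.mul(adapter) fails with a ClassCastException, just like two OrtNDArrays in the question. The adapt helper is only loosely analogous to the NDManager.from() workaround in this thread.

```scala
// Toy model (hypothetical classes, NOT DJL's implementation) of a hybrid
// engine that delegates math to an alternative engine.
object HybridDispatchToy {

  trait Arr { def mul(other: Arr): Arr }

  // "Full" engine array: its kernels only accept arrays of its own type.
  final case class NativeArr(data: Vector[Double]) extends Arr {
    def mul(other: Arr): Arr = {
      val o = other.asInstanceOf[NativeArr] // throws ClassCastException for an AdapterArr
      NativeArr(data.zip(o.data).map { case (a, b) => a * b })
    }
  }

  // "Hybrid" engine array: no kernels of its own, so it converts only
  // itself (not the argument) before delegating.
  final case class AdapterArr(data: Vector[Double]) extends Arr {
    def toNative: NativeArr = NativeArr(data)
    def mul(other: Arr): Arr = toNative.mul(other)
  }

  // Convert the argument up front (hypothetical helper).
  def adapt(a: Arr): Arr = a match {
    case x: AdapterArr => x.toNative
    case other         => other
  }

  def main(args: Array[String]): Unit = {
    val ort1 = AdapterArr(Vector(1.0, 2.0))
    val ort2 = AdapterArr(Vector(3.0, 4.0))
    println(ort1.mul(NativeArr(Vector(3.0, 4.0)))) // NativeArr(Vector(3.0, 8.0))
    println(ort1.mul(adapt(ort2)))                  // NativeArr(Vector(3.0, 8.0))
    try { ort1.mul(ort2); () }
    catch { case e: ClassCastException => println("ClassCastException: " + e.getMessage) }
  }
}
```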

You can easily work around this issue in your code. There is an NDManager.from() API: instead of relying on a cast from OrtNDArray to MxNDArray, you can adapt the NDArray:

// Adapt attention_mask through the NDManager instead of relying on an implicit cast
val manager = embeddings.getManager
val numerator = embeddings.mul(manager.from(attention_mask))