Closed AbhishekBose closed 1 year ago
@AbhishekBose
Hybrid engines (OnnxRuntime in this case) don't support NDArray operations, so we have to delegate them to PyTorch or MXNet. We try to make this automatic, but it seems there are certain cases where it cannot be done: OrtNDArray.mul(MxNDArray array) should work, but MxNDArray.mul(OrtNDArray array) doesn't.
But you can easily work around this issue in your code. There is an NDManager.from() API: instead of trying to cast the OrtNDArray to MxNDArray, you can adapt the NDArray:
val manager = embeddings.getManager()
val numerator = embeddings.mul(manager.from(attention_mask))
I am trying to implement a pooling operation on the output of an ONNX model.
The code for mean pooling is given below.
I am calling this function from the processOutput method of the translator. When it is called, I can see that both embeddings and attention_mask are OrtNDArrays of shape (1, 7, 768), but at the line where I do the element-wise multiplication (val numerator = embeddings.mul(attention_mask)), I get the following error:
java.lang.ClassCastException: ai.djl.onnxruntime.engine.OrtNDArray cannot be cast to ai.djl.mxnet.engine.MxNDArray
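For readers following along, here is a minimal, library-independent sketch of what a masked mean pooling step typically computes. It is not the code from this issue and uses no DJL types; the class name MaskedMeanPooling, the method meanPool, and the small epsilon guard on the denominator are assumptions made for illustration. It assumes the mask has already been expanded to the same (batch, seqLen, hidden) shape as the embeddings, as described above.

```java
// Library-independent sketch of masked mean pooling (hypothetical helper, no DJL types).
final class MaskedMeanPooling {

    // embeddings and mask both have shape (batch, seqLen, hidden); mask entries are 0 or 1.
    static float[][] meanPool(float[][][] embeddings, float[][][] mask) {
        int batch = embeddings.length;
        int seqLen = embeddings[0].length;
        int hidden = embeddings[0][0].length;
        float[][] pooled = new float[batch][hidden];
        for (int b = 0; b < batch; b++) {
            for (int h = 0; h < hidden; h++) {
                float numerator = 0f;   // sum over tokens of embedding * mask
                float denominator = 0f; // number of unmasked tokens
                for (int t = 0; t < seqLen; t++) {
                    numerator += embeddings[b][t][h] * mask[b][t][h];
                    denominator += mask[b][t][h];
                }
                // Assumed guard: avoid dividing by zero when every token is masked.
                pooled[b][h] = numerator / Math.max(denominator, 1e-9f);
            }
        }
        return pooled;
    }
}
```

The element-wise product inside the loop corresponds to the embeddings.mul(attention_mask) call that fails here; the sum over tokens and the division would follow it.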
A screenshot of the array contents is attached.
I can also see that the manager is the same for both arrays, since I pass the same translator ctx into the mean pooling method as well.
cc: @frankfliu