1) ai.djl.engine.rust.NDArrayTests.testExpandDim: This is a candle bug. I have raised an issue in candle and will add the code back once it is fixed.
ai.djl.engine.EngineException: DriverError(CUDA_ERROR_INVALID_VALUE, "invalid argument")
at app//ai.djl.engine.rust.RustLibrary.contentEqual(Native Method)
at app//ai.djl.engine.rust.RsNDArray.contentEquals(RsNDArray.java:368)
at app//ai.djl.engine.rust.RsNDArray.equals(RsNDArray.java:1615)
at app//org.testng.Assert.areEqualImpl(Assert.java:180)
at app//org.testng.Assert.assertEqualsImpl(Assert.java:148)
at app//org.testng.Assert.assertEquals(Assert.java:132)
at app//org.testng.Assert.assertEquals(Assert.java:644)
at app//ai.djl.engine.rust.NDArrayTests.testExpandDim(NDArrayTests.java:225)
2) ai.djl.engine.rust.NDArrayTests.testToDataType: Changed the test to cast int32 to float16 instead of int64 to float16, since the cast_i64_f16 kernel is not supported on CUDA.
ai.djl.engine.EngineException: Cuda(Load { cuda: DriverError(CUDA_ERROR_NOT_FOUND, "named symbol not found"), module_name: "cast_i64_f16" })
at app//ai.djl.engine.rust.RustLibrary.toDataType(Native Method)
at app//ai.djl.engine.rust.RsNDArray.toType(RsNDArray.java:175)
at app//ai.djl.engine.rust.RsNDArray.toType(RsNDArray.java:35)
at app//ai.djl.engine.rust.NDArrayTests.testToDataType(NDArrayTests.java:110)
Description
Fix NDArrayTests failures on CUDA.