deepjavalibrary / djl

An Engine-Agnostic Deep Learning Framework in Java
https://djl.ai
Apache License 2.0

[rust] Fix NDArrayTests failure on cuda #3319

Closed: xyang16 closed this pull request 4 months ago.

xyang16 commented 4 months ago

Description


Fix NDArrayTests failures on CUDA.

1) ai.djl.engine.rust.NDArrayTests.testExpandDim: This failure is caused by a bug in candle. I have raised an issue in candle and will add the removed code back once it is fixed.

ai.djl.engine.EngineException: DriverError(CUDA_ERROR_INVALID_VALUE, "invalid argument")
    at app//ai.djl.engine.rust.RustLibrary.contentEqual(Native Method)
    at app//ai.djl.engine.rust.RsNDArray.contentEquals(RsNDArray.java:368)
    at app//ai.djl.engine.rust.RsNDArray.equals(RsNDArray.java:1615)
    at app//org.testng.Assert.areEqualImpl(Assert.java:180)
    at app//org.testng.Assert.assertEqualsImpl(Assert.java:148)
    at app//org.testng.Assert.assertEquals(Assert.java:132)
    at app//org.testng.Assert.assertEquals(Assert.java:644)
    at app//ai.djl.engine.rust.NDArrayTests.testExpandDim(NDArrayTests.java:225)

2) ai.djl.engine.rust.NDArrayTests.testToDataType: Changed the test to cast int32 (instead of int64) to float16, because the cast_i64_f16 CUDA kernel is not available.

ai.djl.engine.EngineException: Cuda(Load { cuda: DriverError(CUDA_ERROR_NOT_FOUND, "named symbol not found"), module_name: "cast_i64_f16" })
    at app//ai.djl.engine.rust.RustLibrary.toDataType(Native Method)
    at app//ai.djl.engine.rust.RsNDArray.toType(RsNDArray.java:175)
    at app//ai.djl.engine.rust.RsNDArray.toType(RsNDArray.java:35)
    at app//ai.djl.engine.rust.NDArrayTests.testToDataType(NDArrayTests.java:110)