deepjavalibrary / djl

An Engine-Agnostic Deep Learning Framework in Java
https://djl.ai
Apache License 2.0
4.07k stars 648 forks

Different results with Python #1306

Closed mengpengfei closed 2 years ago

mengpengfei commented 2 years ago

This project made me very excited, but when I use it, some problems are too hard for me to solve. I use DJL to run inference on an ONNX model, but the result from DJL is different from the result from Python. This causes lower accuracy when I use DJL to run inference on my ONNX model.

The Python result is "normalize", which is correct, but the Java result is "normaiz", which is wrong. I have run into many similar problems. The Python code is:

def test_onnx():
    pic=cv_imread(r'D:\QQ截图20211018134728.png')
    pic=cv2.resize(pic, dsize=(128,32), fx=0, fy=0,interpolation=cv2.INTER_LINEAR)
    # pic=cv2.cvtColor(pic,cv2.COLOR_BGR2GRAY)
    # pic=cv2.cvtColor(pic,cv2.COLOR_GRAY2BGR)
    cv2.imwrite('./abcd.png',pic)
    # pic = cv2.imencode('.jpg',pic)[1]
    # nparr = np.fromstring(pic, np.uint8)
    # pic = cv2.imdecode(nparr,-1)
    pic=np.transpose(pic,(2,0,1))
    tensor=torch.from_numpy(pic).type(torch.float32).to(DEVICE)
    # trans=transforms.Resize((config.input_h,config.input_w),interpolation=InterpolationMode.BILINEAR)
    # tensor=trans(tensor)
    mean_std = (config.MEAN, config.STD)
    input_transform = transforms.Compose([
        # transforms.ToTensor(),
        transforms.Normalize(*mean_std)
    ])
    tensor=input_transform(tensor)

    tensor=torch.unsqueeze(tensor,0)

    # tensor=torch.rand(1,3,56,56)

    model = onnxruntime.InferenceSession(r'D:\code\cbc\demo\eng_num_checkpoint.onnx')
    batch_x = {
        'input_tensors':to_numpy(tensor)
    }
    for i in range(0,1):
        cur1=round(time.time() * 1000)
        preds = model.run(output_names=['preds'], input_feed=batch_x)
        cur2=round(time.time() * 1000)
        p=cur2-cur1
        print("Recognition result, elapsed time: %d"%(p))
    preds=torch.tensor(preds[0])
    preds_size = torch.IntTensor([16]*1)
    res = converter.decode( preds.data, preds_size.data, raw=False)
    res = res.strip()
    print(res)

The Java code:

public static void main(String[] args) throws IOException, MalformedModelException, ModelNotFoundException, TranslateException {
        try (NDManager manager = NDManager.newBaseManager(Device.cpu(), "PyTorch")) {
            Criteria<NDList, NDList> criteria =
                    Criteria.builder()
                            .setTypes(NDList.class, NDList.class)
                            .optEngine("OnnxRuntime").optModelPath(Paths.get("D:\\code\\cbc\\demo\\eng_num_checkpoint.onnx"))
                            .optProgress(new ProgressBar())
                            .optDevice(Device.cpu())
                            .build();
            ZooModel<NDList, NDList> model = ModelZoo.loadModel(criteria);
            Predictor<NDList, NDList> predictor = model.newPredictor();
            Path path = Paths.get("D:\\QQ截图20211018134728.png");
            Mat mat = byte2Mat(Files.readAllBytes(path));
            Mat mat1 = new Mat();
            Imgproc.resize(mat, mat1, new Size(128, 32),0,0,Imgproc.INTER_LINEAR);

            imwrite("./random.png",mat1);
            byte[] bytes = mat2Byte(mat1, ".png");

            NDArray imageArray =
                    MyBufferedImageFactory.getInstance()
                            .fromInputStream(new ByteArrayInputStream(bytes))
                            .toNDArray(manager);

            NDArray ndArray = imageArray
                    .expandDims(0)
                    .transpose(new int[]{0,3,1,2})
                    .toType(DataType.FLOAT32, true);
            NDList ndArrays = new NDList(ndArray);

            Pipeline pipeline = new Pipeline();
            float[] MEAN = new float[]{0.485f, 0.456f, 0.406f};
            float[] STD = new float[]{0.229f, 0.224f, 0.225f};

            pipeline.add(new Normalize(MEAN,STD));
            NDList transform = pipeline.transform(ndArrays);

            NDList predict = null;
            for (int i=0;i<1;i++) {
                long currentTimeMillis = System.currentTimeMillis();
                predict=predictor.predict(transform);
                long currentTimeMillis1 = System.currentTimeMillis();
                System.out.println("consume time:"+(currentTimeMillis1-currentTimeMillis));
            }
            System.out.println(predict);
            StringBuilder sb = new StringBuilder();
            for (NDArray arr1 : predict) {
                long[] ints = arr1.toLongArray();
                for (int i = 0; i < ints.length; i++) {
                    if (ints[i] != 0 && !(i > 0 && ints[i - 1] == ints[i])) {
                        sb.append(alphabet_list[(int) ints[i]]);
                    }
                }
            }
            System.out.println(sb.toString());
//                Thread.currentThread().sleep(10000000);
//            }catch (Exception e){
//                e.printStackTrace();
//            }
        }
    }

Please help me! Thanks very much. I look forward to your answer.

frankfliu commented 2 years ago

@mengpengfei The forward pass should be identical between Java and Python. We share the same underlying C++ API.

If the results differ, the most likely cause is preprocessing (in most cases, resizing). Can you share your model so we can test your code?

mengpengfei commented 2 years ago

I can send my model to your email, but please do not make the model public. Is that OK? If so, please give me your email address.

mengpengfei commented 2 years ago

The following is the test picture:

QQ截图20211018134728

mengpengfei commented 2 years ago

@frankfliu Can you help me? Thank you anyway.

frankfliu commented 2 years ago

@mengpengfei we will try to compare your image processing first. I will ask for your model if needed. Most likely the difference comes from resize().

It looks like you didn't resize in Python but did resize in Java. Is that intentional?

mengpengfei commented 2 years ago

@frankfliu No, it is not intentional. I resized the data in Python too. The image processing operations are the same. image

mengpengfei commented 2 years ago

image

The picture above shows the version that I used.

frankfliu commented 2 years ago

@mengpengfei I tried OpenCV in Java, and I'm able to generate identical results between DJL and the Python code.

I noticed your image processing is wrong in both the Python and the Java code. Here is the code I tested:

Python:

    pic = cv2.imread('test.png')
    pic = cv2.resize(pic, dsize=(128, 32), fx=0, fy=0, interpolation=cv2.INTER_LINEAR)
    pic = np.asarray(pic)
    mean_std = (MEAN, STD)
    input_transform = transforms.Compose([
        transforms.ToTensor(),
        transforms.Normalize(*mean_std)
    ])
    tensor = input_transform(pic)
    tensor = torch.unsqueeze(tensor, 0)

    with open("python.bin", "wb") as f:
        f.write(tensor.detach().numpy().tobytes())
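
The key difference from the original Python snippet is transforms.ToTensor(), which converts the uint8 HWC image to a float CHW tensor scaled to [0, 1] before Normalize is applied. A minimal sketch of the per-pixel arithmetic, in plain Python with no torch dependency, shows how much skipping that scaling changes the value the model sees. The helper name preprocess_pixel is hypothetical and exists only for this illustration; the MEAN and STD values are the ones used in the Java code in this thread.

```python
# Per-channel statistics as used in the thread's Java code.
MEAN = (0.485, 0.456, 0.406)
STD = (0.229, 0.224, 0.225)


def preprocess_pixel(raw, channel, scale=True):
    """Normalize one uint8 pixel value (0..255) for the given channel.

    scale=True mimics ToTensor followed by Normalize: divide by 255 first.
    scale=False mimics Normalize applied directly to the raw 0..255 value.
    (Hypothetical helper, for illustration only.)
    """
    value = raw / 255.0 if scale else float(raw)
    return (value - MEAN[channel]) / STD[channel]


if __name__ == "__main__":
    raw = 128
    # With scaling the result stays near zero (about 0.07 for channel 0).
    print("scaled:  ", preprocess_pixel(raw, 0, scale=True))
    # Without scaling the result is in the hundreds (about 557 for channel 0).
    print("unscaled:", preprocess_pixel(raw, 0, scale=False))
```

If the model was trained on inputs in the first range, feeding it values in the second range would be enough on its own to change the predictions.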

Java code:

OpenCV.loadShared();

float[] MEAN = new float[]{0.485f, 0.456f, 0.406f};
float[] STD = new float[]{0.229f, 0.224f, 0.225f};

try (NDManager manager = NDManager.newBaseManager()) {
    String file = "test.png";
    Mat mat = Imgcodecs.imread(file);
    Mat mat1 = new Mat();
    Imgproc.resize(mat, mat1, new Size(128, 32), 0, 0, Imgproc.INTER_LINEAR);
    byte[] buf = new byte[128 * 32 * mat1.channels()];
    mat1.get(0, 0, buf);

    Shape shape = new Shape(32, 128, mat1.channels());
    NDArray imageArray = manager.create(ByteBuffer.wrap(buf), shape, DataType.UINT8);

//    NDArray ndArray = imageArray
//            .expandDims(0)
//            .div(255.0) // you missed this
//            .transpose(0, 3, 1, 2)
//            .toType(DataType.FLOAT32, true);
//
//    Normalize normalize = new Normalize(MEAN, STD);
//    NDArray ret = normalize.transform(ndArray);

    Pipeline pipeline = new Pipeline();
    pipeline.add(new ToTensor())
            .add(new Normalize(MEAN, STD))
            .add(a -> a.expandDims(0));
    NDList list = pipeline.transform(new NDList(imageArray));

    buf = list.head().toByteArray();
    try (OutputStream os = Files.newOutputStream(Paths.get("java_cv.bin"))) {
        os.write(buf);
    }
}
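
To check whether the two preprocessing pipelines really agree, the dumped tensors can be compared element by element. The following is a rough sketch rather than part of the tested code: it assumes both python.bin and java_cv.bin contain raw float32 values in little-endian byte order, and the function names read_float32 and max_abs_diff are hypothetical.

```python
import struct


def read_float32(path):
    """Read a raw binary file of little-endian float32 values into a list.

    Assumption: the file is a flat dump with no header.
    """
    with open(path, "rb") as f:
        data = f.read()
    if len(data) % 4 != 0:
        raise ValueError(f"{path}: size {len(data)} is not a multiple of 4")
    return list(struct.unpack(f"<{len(data) // 4}f", data))


def max_abs_diff(a, b):
    """Return the largest absolute element-wise difference between a and b."""
    if len(a) != len(b):
        raise ValueError(f"length mismatch: {len(a)} vs {len(b)}")
    return max((abs(x - y) for x, y in zip(a, b)), default=0.0)
```

For example, max_abs_diff(read_float32("python.bin"), read_float32("java_cv.bin")) should be 0.0 or very close to it when the preprocessing matches. A length mismatch points to a shape problem, while a large difference points to a scaling or channel-layout problem.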

frankfliu commented 2 years ago

@mengpengfei by the way, do you have to use cv2 to resize? Can you use transforms.Resize()? DJL's Resize() is identical to transforms.Resize().

mengpengfei commented 2 years ago

When I use the following code, the problem is solved. Thank you very much!

// Copy the resized Mat's raw pixel bytes (32 rows x 128 columns x 3 channels) into a buffer
byte[] bytes = new byte[32 * 128 * 3];
int i1 = mat1.get(0, 0, bytes);
Shape shape = new Shape(32, 128, 3);
NDArray imageArray = manager.create(ByteBuffer.wrap(bytes), shape, DataType.UINT8);