deepjavalibrary / djl

An Engine-Agnostic Deep Learning Framework in Java
https://djl.ai
Apache License 2.0

Criteria class not able to read downloaded bertqa.pt file #2217

Closed: tang-john closed this issue 1 year ago

tang-john commented 1 year ago

I used the code from https://github.com/deepjavalibrary/djl/blob/master/jupyter/pytorch/load_your_own_pytorch_bert.ipynb to download https://djl-ai.s3.amazonaws.com/mlrepo/model/nlp/question_answer/ai/djl/pytorch/bertqa/0.0.1/trace_bertqa.pt.gz and extract it to build/pytorch/bertqa/bertqa.pt. The file is there, but loading it fails with the following error:

    Exception in thread "main" java.io.FileNotFoundException: Parameter file with prefix: bertqa not found in: /home/jtang/temp/dlj/build/pytorch/bertqa or not readable by the engine.
        at ai.djl.mxnet.engine.MxModel.load(MxModel.java:109)
        at ai.djl.repository.zoo.BaseModelLoader.loadModel(BaseModelLoader.java:161)
        at ai.djl.repository.zoo.Criteria.loadModel(Criteria.java:168)
        at org.dlj.example.Example.main(Example.java:70)

    Process finished with exit code 1

The code that triggers the error is:

    Criteria<QAInput, String> criteria = Criteria.builder()
            .setTypes(QAInput.class, String.class)
            .optModelPath(Paths.get("build/pytorch/bertqa/")) // search in local folder
            .optTranslator(translator)
            .optProgress(new ProgressBar()).build();

The sample code downloads the file successfully, so why can't DJL read it?
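To see why the MXNet loader reports a missing parameter file, it helps to compare what the model directory actually holds with what each engine expects. The sketch below is a hypothetical helper (ModelDirCheck.guessEngines is not a DJL API). It assumes the common conventions that a traced PyTorch model is saved as <name>.pt and an MXNet model as <name>-symbol.json plus <name>-NNNN.params, and that DJL uses the directory name (bertqa) as the model name, which matches the prefix in the error message.

```java
import java.io.IOException;
import java.nio.file.DirectoryStream;
import java.nio.file.Files;
import java.nio.file.Path;
import java.nio.file.Paths;
import java.util.ArrayList;
import java.util.List;

public class ModelDirCheck {

    /**
     * Hypothetical helper: guesses which engine(s) the files in modelDir were saved for,
     * based only on file-naming conventions (assumption, not DJL's own logic).
     */
    static List<String> guessEngines(Path modelDir, String prefix) throws IOException {
        List<String> engines = new ArrayList<>();
        // A TorchScript model traced from PyTorch is conventionally stored as <prefix>.pt
        if (Files.isRegularFile(modelDir.resolve(prefix + ".pt"))) {
            engines.add("PyTorch");
        }
        // MXNet models are conventionally stored as <prefix>-symbol.json plus <prefix>-NNNN.params
        try (DirectoryStream<Path> stream = Files.newDirectoryStream(modelDir, prefix + "-*.params")) {
            if (stream.iterator().hasNext()) {
                engines.add("MXNet");
            }
        }
        return engines;
    }

    public static void main(String[] args) throws IOException {
        Path dir = Paths.get(args.length > 0 ? args[0] : "build/pytorch/bertqa");
        // Use the directory name as the model-name prefix, e.g. "bertqa"
        System.out.println(guessEngines(dir, dir.getFileName().toString()));
    }
}
```

For the directory in this issue, the helper would report only PyTorch, which is consistent with the MXNet loader failing to find a bertqa-*.params file.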

frankfliu commented 1 year ago

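Your project includes multiple engines on its classpath, and in your case DJL picked MXNet as the default engine, so it looked for MXNet model files instead of the TorchScript file bertqa.pt. You can either set the default engine to PyTorch with a JVM system property:

    -Dai.djl.default_engine=PyTorch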
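If editing the JVM command line is inconvenient, the same system property can be set from code. This is a minimal sketch, assuming DJL reads ai.djl.default_engine when it first resolves the default engine, so the call has to happen before any Criteria is built or any model is loaded. ensureDefaultEngine is a hypothetical helper name; it leaves an explicit -D flag from the command line untouched.

```java
public class DefaultEngineSetup {

    /** Property name taken from the -D flag above. */
    static final String ENGINE_PROPERTY = "ai.djl.default_engine";

    /**
     * Sets the default engine unless one was already given with -Dai.djl.default_engine=...
     * Returns the engine name that is in effect afterwards.
     */
    static String ensureDefaultEngine(String engine) {
        String current = System.getProperty(ENGINE_PROPERTY);
        if (current == null || current.isEmpty()) {
            System.setProperty(ENGINE_PROPERTY, engine);
            return engine;
        }
        return current;
    }

    public static void main(String[] args) {
        // Assumption: must run before any DJL class resolves the default engine
        String engine = ensureDefaultEngine("PyTorch");
        System.out.println("Default engine: " + engine);
        // ... build the Criteria and call criteria.loadModel() after this point ...
    }
}
```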

Or you can tell Criteria which engine you want to use:

    Criteria<QAInput, String> criteria = Criteria.builder()
            .setTypes(QAInput.class, String.class)
            .optModelPath(Paths.get("build/pytorch/bertqa/")) // search in local folder
            .optTranslator(translator)
            .optEngine("PyTorch")
            .optProgress(new ProgressBar()).build();

frankfliu commented 1 year ago

Another way to tell DJL which engine to use is to add a serving.properties file to your model directory (in this case, the build/pytorch/bertqa folder) containing:

    engine=PyTorch
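The file can be created by hand, or from Java as sketched below. This assumes DJL parses serving.properties as a standard java.util.Properties file, so the timestamp comment line that Properties.store writes is ignored. writeEngine is a hypothetical helper that keeps any other keys already in the file and expects the model directory to exist.

```java
import java.io.IOException;
import java.io.Reader;
import java.io.Writer;
import java.nio.charset.StandardCharsets;
import java.nio.file.Files;
import java.nio.file.Path;
import java.nio.file.Paths;
import java.util.Properties;

public class ServingPropertiesWriter {

    /** Writes engine=<engine> into <modelDir>/serving.properties, preserving existing keys. */
    static Path writeEngine(Path modelDir, String engine) throws IOException {
        Path file = modelDir.resolve("serving.properties");
        Properties props = new Properties();
        // Load the current contents first so unrelated settings are not lost
        if (Files.isRegularFile(file)) {
            try (Reader in = Files.newBufferedReader(file, StandardCharsets.UTF_8)) {
                props.load(in);
            }
        }
        props.setProperty("engine", engine);
        try (Writer out = Files.newBufferedWriter(file, StandardCharsets.UTF_8)) {
            props.store(out, null); // also writes a "#<date>" comment line
        }
        return file;
    }

    public static void main(String[] args) throws IOException {
        Path file = writeEngine(Paths.get("build/pytorch/bertqa"), "PyTorch");
        System.out.println("Wrote " + file);
    }
}
```

Running it once after the model is extracted should leave a file whose only non-comment line is engine=PyTorch.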