deepjavalibrary / djl

An Engine-Agnostic Deep Learning Framework in Java
https://djl.ai
Apache License 2.0

[HELP NEEDED] How to load a PyTorch Model? #2062

Closed: nnahito closed this issue 2 years ago

nnahito commented 2 years ago

I would like to know how to load and use PyTorch models on a Mac with an M1 CPU.

What I want to do

I would like to use the following model with DJL: https://github.com/UKPLab/EasyNMT

However, the following error occurs:

Caused by: java.lang.IllegalStateException: No PyTorch native library matches your operating system: cpu-osx-aarch64:1.9.1

How can I fix this? And once it is fixed, how do I actually use the model?

Gradle

dependencies {
    implementation 'ai.djl:api:0.19.0'
    implementation 'ai.djl.pytorch:pytorch-engine:0.19.0'
    implementation 'ai.djl.pytorch:pytorch-native-auto:1.9.1'
    implementation 'org.slf4j:slf4j-api:2.0.3'
}

Java Code

package org.example;

import ai.djl.*;

import java.io.IOException;
import java.nio.file.*;

public class Main {
    public static void main(String[] args) throws MalformedModelException, IOException {
        Path modelDir = Paths.get("resources/mbart50/");
        Model model = Model.newInstance("mbart50.pt");
        model.load(modelDir);
    }
}

Thank you very much.
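The lookup key in the error above, `cpu-osx-aarch64:1.9.1`, appears to combine a flavor (`cpu`), an OS/architecture classifier (`osx-aarch64`) and the PyTorch version (`1.9.1`). The error means no native build was found for that combination. The sketch below is a hypothetical helper, not DJL's own platform-detection code, that composes a key of the same shape from the JVM's `os.name` and `os.arch` values. It is a quick way to check which classifier your JVM actually needs: an arm64 JDK on an M1 Mac reports `aarch64`, while an x86_64 JDK running under Rosetta 2 reports `x86_64`.

```java
import java.util.Locale;

/**
 * Hypothetical helper (NOT DJL's own platform detection) that builds a
 * native-library lookup key shaped like the one in the error message,
 * e.g. "cpu-osx-aarch64:1.9.1".
 */
class NativeClassifier {

    /** Maps the JVM's os.name value (e.g. "Mac OS X") to a short OS prefix. */
    static String osPrefix(String osName) {
        String name = osName.toLowerCase(Locale.ROOT);
        if (name.startsWith("mac")) {
            return "osx";
        } else if (name.startsWith("windows")) {
            return "win";
        } else if (name.startsWith("linux")) {
            return "linux";
        }
        throw new IllegalArgumentException("Unsupported OS: " + osName);
    }

    /** Normalizes the JVM's os.arch value to the names used in the classifier. */
    static String arch(String osArch) {
        switch (osArch) {
            case "amd64":
            case "x86_64":
                return "x86_64";
            case "aarch64":
            case "arm64":
                return "aarch64";
            default:
                throw new IllegalArgumentException("Unsupported architecture: " + osArch);
        }
    }

    /** Returns "<os>-<arch>", e.g. "osx-aarch64". */
    static String classifier(String osName, String osArch) {
        return osPrefix(osName) + "-" + arch(osArch);
    }

    /** Returns "<flavor>-<os>-<arch>:<version>", the shape seen in the error. */
    static String lookupKey(String flavor, String osName, String osArch, String version) {
        return flavor + "-" + classifier(osName, osArch) + ":" + version;
    }
}
```

Calling `NativeClassifier.classifier(System.getProperty("os.name"), System.getProperty("os.arch"))` on your own machine shows which classifier to look for. If it prints `osx-aarch64`, the native artifact and PyTorch version you choose must actually be published for that classifier.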

frankfliu commented 2 years ago

A few recommendations for your project setup:

  1. We strongly suggest using the BOM; it manages DJL versions for you.
  2. pytorch-native-auto is deprecated, see: https://docs.djl.ai/master/docs/development/dependency_management.html#list-of-djl-packages-published-on-maven-central
  3. Use the Criteria API if possible; it provides a better abstraction than the low-level Model API.
  4. You can use this project as a template for your own application: https://github.com/deepjavalibrary/djl-demo/tree/master/development/fatjar

nnahito commented 2 years ago

@frankfliu Thanks for your reply. I immediately tried your suggestions.

However, the model still could not be loaded. Is it simply that EasyNMT does not exist in the model zoo, so it cannot be used at all?

The error was as follows:

Exception in thread "main" ai.djl.repository.zoo.ModelNotFoundException: No matching model with specified Input/Output type found.
    at ai.djl.repository.zoo.Criteria.loadModel(Criteria.java:180)

I am new to both Java and English, so I apologize if this is hard to read. Please let me know if you have any suggestions.

Thank you very much.

Gradle

dependencies {
    implementation platform("ai.djl:bom:0.19.0")

    implementation "ai.djl:api"
    // Use PyTorch engine
    runtimeOnly "ai.djl.pytorch:pytorch-engine"
    // PyTorch JNI offline distribution package
    runtimeOnly "ai.djl.pytorch:pytorch-jni"
    // Use the native library that matches your target platform
    runtimeOnly "ai.djl.pytorch:pytorch-native-cpu::osx-x86_64"

    implementation 'org.slf4j:slf4j-simple:1.7.36'

    runtimeOnly "ai.djl.pytorch:pytorch-model-zoo"
}

Java Code

package org.example;

import ai.djl.*;
import ai.djl.inference.*;
import ai.djl.repository.zoo.Criteria;
import ai.djl.repository.zoo.ModelNotFoundException;
import ai.djl.repository.zoo.ZooModel;
import ai.djl.training.util.ProgressBar;
import ai.djl.translate.*;

import java.io.IOException;
import java.nio.file.*;

public class Main {
    public static void main(String[] args) throws MalformedModelException, IOException, ModelNotFoundException, TranslateException {
        // Paths.get() expects a file-system path, not a file:// URL
        Path modelDir = Paths.get("/Users/user_name/pytorch_test_java/src/main/resources/mbart50/mbart50.pt");

        Criteria<String, String> criteria =
                Criteria.builder()
                        .setTypes(String.class, String.class)
                        .optProgress(new ProgressBar())
                        .optFilter("target_lang", "ja")
                        .optFilter("max_new_tokens", "1000")
                        .optEngine("PyTorch")
                        .optModelPath(modelDir)
                        .build();

        ZooModel<String, String> model = criteria.loadModel();
        Predictor<String, String> predictor = model.newPredictor();
        String classifications = predictor.predict("hello world");
        System.out.println(classifications);
    }
}

EasyNMT Sample Code

from easynmt import EasyNMT
model = EasyNMT('mbart50_m2m')

while True:
    sentence = input("Enter the Japanese words you wish to translate into English: ")
    print(model.translate(sentence, target_lang='en', max_new_tokens=1000))

frankfliu commented 2 years ago

@nnahito

  1. .optFilter() is only used when you are loading a built-in model zoo model.
  2. When you load a model from the local file system, you usually need to supply your own translator: .optTranslator(new MyTranslator()). Otherwise, you have to use raw input/output: Criteria<NDList, NDList>

Here is an NMT example for PyTorch: https://github.com/deepjavalibrary/djl/blob/master/examples/src/main/java/ai/djl/examples/inference/NeuralMachineTranslation.java
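To make the type-matching point concrete: without `.optTranslator(...)`, a model loaded from a local path can only be served for input/output types the loader already knows how to convert, and raw tensors in and out (`NDList`) need no conversion at all. The toy resolver below illustrates that rule under stated assumptions. `TranslatorResolver` and `RawTensors` are hypothetical names, and DJL's real lookup (through its translator factories) is more involved than this.

```java
import java.util.HashSet;
import java.util.Set;

/**
 * Hypothetical, simplified model of picking a translator by
 * (input type, output type). This is NOT DJL's real implementation.
 */
class TranslatorResolver {

    /** Placeholder standing in for DJL's NDList (raw tensors). */
    static final class RawTensors {}

    // Type pairs that can be served without a user-supplied translator.
    private final Set<String> builtInPairs = new HashSet<>();

    TranslatorResolver() {
        // Raw tensors in, raw tensors out: nothing to convert, so always available.
        builtInPairs.add(key(RawTensors.class, RawTensors.class));
    }

    private static String key(Class<?> in, Class<?> out) {
        return in.getName() + "->" + out.getName();
    }

    /**
     * Returns which kind of translator would be used, or throws when none fits.
     * An explicitly supplied translator always wins, as with optTranslator().
     */
    String resolve(Class<?> in, Class<?> out, Object userTranslator) {
        if (userTranslator != null) {
            return "custom";
        }
        if (builtInPairs.contains(key(in, out))) {
            return "built-in";
        }
        // Same message as the exception reported in this thread.
        throw new IllegalStateException(
                "No matching model with specified Input/Output type found.");
    }
}
```

In this toy, `resolve(String.class, String.class, null)` throws, which is the situation of the `Criteria<String, String>` code above, while supplying a translator or switching to raw tensors both succeed.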

nnahito commented 2 years ago

@frankfliu Thank you for your reply.

So the PyTorch model cannot be used as is? Do you have any samples that use optTranslator?

Thank you very much.

frankfliu commented 2 years ago

@nnahito Most models need to pre-process the input data into tensors and post-process the output tensors into meaningful data.

DJL has a few built-in Translators for common use cases, but they cannot cover every model. For NMT, each model does this in its own way, so we don't have a general Translator for it.

Here is an example of custom Translator: https://github.com/deepjavalibrary/djl/blob/master/examples/src/main/java/ai/djl/examples/inference/BertClassification.java#L91
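The division of labor a Translator provides can be shown without DJL at all. The sketch below is a hypothetical `ToyTextTranslator` whose `processInput`/`processOutput` mirror the two halves of DJL's `Translator` interface, but it uses `long[]` in place of `NDList` and a made-up word-level vocabulary in place of a real tokenizer. A working translator for an mBART-style model would additionally need the model's own subword tokenizer, its target-language handling, and an autoregressive decoding loop, none of which this sketch attempts.

```java
import java.util.ArrayList;
import java.util.HashMap;
import java.util.List;
import java.util.Locale;
import java.util.Map;

/**
 * Hypothetical translator mirroring the pre-/post-processing split of DJL's
 * Translator, using long[] instead of NDList so it runs without DJL.
 * The vocabulary is a made-up stand-in for a real tokenizer.
 */
class ToyTextTranslator {

    static final long UNKNOWN_ID = 0;

    private final Map<String, Long> tokenToId = new HashMap<>();
    private final List<String> idToToken = new ArrayList<>();

    ToyTextTranslator(List<String> vocabulary) {
        idToToken.add("<unk>"); // id 0 is reserved for unknown words
        for (String token : vocabulary) {
            tokenToId.put(token, (long) idToToken.size());
            idToToken.add(token);
        }
    }

    /** Pre-processing: text -> token ids (what the model would consume). */
    long[] processInput(String text) {
        String[] words = text.trim().toLowerCase(Locale.ROOT).split("\\s+");
        long[] ids = new long[words.length];
        for (int i = 0; i < words.length; i++) {
            ids[i] = tokenToId.getOrDefault(words[i], UNKNOWN_ID);
        }
        return ids;
    }

    /** Post-processing: token ids produced by the model -> text. */
    String processOutput(long[] ids) {
        StringBuilder sb = new StringBuilder();
        for (long id : ids) {
            if (sb.length() > 0) {
                sb.append(' ');
            }
            boolean known = id > 0 && id < idToToken.size();
            sb.append(known ? idToToken.get((int) id) : "<unk>");
        }
        return sb.toString();
    }
}
```

In a real DJL translator, `processInput` would receive a `TranslatorContext` and wrap ids like these in an `NDArray` (for example via `ctx.getNDManager().create(ids)`) inside an `NDList`, and `processOutput` would read the model's output `NDList` back into text.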

nnahito commented 2 years ago

@frankfliu Thanks for your response, and for telling me about this. I will try a few things!