deepjavalibrary / djl

An Engine-Agnostic Deep Learning Framework in Java
https://djl.ai
Apache License 2.0
4.13k stars 655 forks source link

Error loading ONNX model, cannot be cast to ai.djl.nn.BlockFactory #1641

Closed filipsedivy closed 2 years ago

filipsedivy commented 2 years ago

Description

When trying to use hybrid inference, TensorFlow + ONNXRuntime, an error occurs when loading the model.

Error Message

To enable them in other operations, rebuild TensorFlow with the appropriate compiler flags.
Exception in thread "main" java.lang.ClassCastException: org.example.Main cannot be cast to ai.djl.nn.BlockFactory
        at ai.djl.repository.zoo.BaseModelLoader.createModel(BaseModelLoader.java:190)
        at ai.djl.repository.zoo.BaseModelLoader.loadModel(BaseModelLoader.java:149)
        at ai.djl.repository.zoo.Criteria.loadModel(Criteria.java:166)
        at org.example.Main.main(Main.java:52)

How to Reproduce?

package org.example;

import ai.djl.MalformedModelException;
import ai.djl.inference.Predictor;
import ai.djl.modality.Classifications;
import ai.djl.modality.cv.Image;
import ai.djl.modality.cv.ImageFactory;
import ai.djl.modality.cv.transform.Resize;
import ai.djl.modality.cv.transform.ToTensor;
import ai.djl.modality.cv.translator.ImageClassificationTranslator;
import ai.djl.repository.zoo.Criteria;
import ai.djl.repository.zoo.ModelNotFoundException;
import ai.djl.repository.zoo.ZooModel;
import ai.djl.translate.TranslateException;

import java.io.BufferedReader;
import java.io.IOException;
import java.io.InputStreamReader;
import java.nio.file.Paths;
import java.util.ArrayList;
import java.util.List;

/**
 * Command-line demo that classifies a single image with a model loaded through the
 * "OnnxRuntime" engine.
 *
 * <p>The program asks for the model path and then the image path on standard input,
 * runs one prediction and prints the resulting classifications.
 */
public class Main {
    /** Number of placeholder labels ("class_1" .. "class_8") handed to the translator. */
    private static final int CLASS_COUNT = 8;

    /**
     * Loads the model, runs a single prediction and prints the result.
     *
     * @param args ignored
     * @throws ModelNotFoundException if no model matching the criteria can be found
     * @throws MalformedModelException if the model files cannot be parsed
     * @throws IOException if reading standard input, the model or the image fails,
     *     or if standard input ends before both paths were entered
     * @throws TranslateException if pre- or post-processing of the prediction fails
     */
    public static void main(String[] args) throws ModelNotFoundException, MalformedModelException, IOException, TranslateException {
        // Plain list instead of double-brace initialization, which would create an
        // anonymous ArrayList subclass just to hold eight strings.
        List<String> synset = new ArrayList<>(CLASS_COUNT);
        for (int i = 1; i <= CLASS_COUNT; i++) {
            synset.add("class_" + i);
        }

        ImageClassificationTranslator translator = ImageClassificationTranslator.builder()
                .optSynset(synset)
                .addTransform(new Resize(224, 224))
                .addTransform(new ToTensor())
                .build();

        // Deliberately not closed: closing this reader would also close System.in.
        BufferedReader br = new BufferedReader(new InputStreamReader(System.in));
        String modelPath = prompt(br, "Model path:");

        Criteria<Image, Classifications> criteria = Criteria.builder()
                .setTypes(Image.class, Classifications.class)
                .optModelPath(Paths.get(modelPath))
                .optTranslator(translator)
                .optEngine("OnnxRuntime")
                .build();

        // Model and predictor are both AutoCloseable; try-with-resources closes them
        // (predictor first, then model) even when prediction throws.
        try (ZooModel<Image, Classifications> model = criteria.loadModel();
             Predictor<Image, Classifications> predictor = model.newPredictor()) {
            String imagePath = prompt(br, "Image path:");

            Image image = ImageFactory.getInstance().fromFile(Paths.get(imagePath));
            Classifications predict = predictor.predict(image);
            System.out.println(predict.getAsString());
        }
    }

    /**
     * Prints {@code message} and reads one line from {@code reader}.
     *
     * @param reader source of the user's answer
     * @param message prompt printed before reading
     * @return the line entered by the user, never {@code null}
     * @throws IOException if reading fails or the input ends before a line was entered
     */
    private static String prompt(BufferedReader reader, String message) throws IOException {
        System.out.println(message);
        String line = reader.readLine();
        if (line == null) {
            // readLine() signals end of input with null; fail here with a clear message
            // instead of letting Paths.get(null) throw a NullPointerException later.
            throw new IOException("Standard input ended before a value was entered for: " + message);
        }
        return line;
    }
}

Maven

<?xml version="1.0" encoding="UTF-8"?>
<project xmlns="http://maven.apache.org/POM/4.0.0"
         xmlns:xsi="http://www.w3.org/2001/XMLSchema-instance"
         xsi:schemaLocation="http://maven.apache.org/POM/4.0.0 http://maven.apache.org/xsd/maven-4.0.0.xsd">
    <!--
      Build for the demo application: DJL 0.16.0 with the TensorFlow and ONNX Runtime
      engines on the classpath, packaged into one shaded (fat) JAR whose entry point
      is org.example.Main.
    -->
    <modelVersion>4.0.0</modelVersion>

    <groupId>org.example</groupId>
    <artifactId>demo2</artifactId>
    <version>1.0-SNAPSHOT</version>

    <properties>
        <!-- Compile sources as Java 8 and emit Java 8 bytecode. -->
        <maven.compiler.source>8</maven.compiler.source>
        <maven.compiler.target>8</maven.compiler.target>
    </properties>

    <dependencies>
        <!-- DJL core API; the only DJL artifact left at the default (compile) scope besides the model zoo. -->
        <dependency>
            <groupId>ai.djl</groupId>
            <artifactId>api</artifactId>
            <version>0.16.0</version>
        </dependency>

        <!-- TensorFlow model zoo. -->
        <dependency>
            <groupId>ai.djl.tensorflow</groupId>
            <artifactId>tensorflow-model-zoo</artifactId>
            <version>0.16.0</version>
        </dependency>

        <!-- TensorFlow engine implementation; runtime scope, so it is not on the compile classpath. -->
        <dependency>
            <groupId>ai.djl.tensorflow</groupId>
            <artifactId>tensorflow-engine</artifactId>
            <version>0.16.0</version>
            <scope>runtime</scope>
        </dependency>

        <!-- Native TensorFlow CPU binaries; the win-x86_64 classifier pins this build to 64-bit Windows. -->
        <dependency>
            <groupId>ai.djl.tensorflow</groupId>
            <artifactId>tensorflow-native-cpu</artifactId>
            <classifier>win-x86_64</classifier>
            <scope>runtime</scope>
            <version>2.7.0</version>
        </dependency>

        <!-- ONNX Runtime engine implementation; runtime scope, so it is not on the compile classpath. -->
        <dependency>
            <groupId>ai.djl.onnxruntime</groupId>
            <artifactId>onnxruntime-engine</artifactId>
            <version>0.16.0</version>
            <scope>runtime</scope>
        </dependency>
    </dependencies>

    <build>
        <!-- Base name of the built artifact, used instead of artifactId-version. -->
        <finalName>application</finalName>

        <plugins>
            <!-- Bundles the project and its dependencies into a single JAR during the package phase. -->
            <plugin>
                <groupId>org.apache.maven.plugins</groupId>
                <artifactId>maven-shade-plugin</artifactId>
                <version>3.3.0</version>
                <configuration>
                    <transformers>
                        <!-- Writes Main-Class into the manifest so the shaded JAR can be started with java -jar. -->
                        <transformer
                                implementation="org.apache.maven.plugins.shade.resource.ManifestResourceTransformer">
                            <mainClass>org.example.Main</mainClass>
                        </transformer>
                        <!--
                          Merges META-INF/services files from all bundled JARs instead of letting one overwrite another.
                          NOTE(review): presumably required so that both DJL engines stay discoverable in the
                          shaded JAR - confirm against the DJL engine loading documentation.
                        -->
                        <transformer
                                implementation="org.apache.maven.plugins.shade.resource.ServicesResourceTransformer"/>
                    </transformers>
                </configuration>
                <executions>
                    <execution>
                        <phase>package</phase>
                        <goals>
                            <goal>shade</goal>
                        </goals>
                        <configuration>
                            <filters>
                                <!-- Applies to every bundled artifact (*:*). -->
                                <filter>
                                    <artifact>*:*</artifact>
                                    <!-- The .SF/.DSA/.RSA entries are JAR signature files; they no longer match the repackaged contents. -->
                                    <excludes>
                                        <exclude>PamModel/*.*</exclude>  <!--  don't include files in the PamModel folder -->
                                        <exclude>META-INF/*.SF</exclude> <!-- get rid of manifests from library jars -->
                                        <exclude>META-INF/*.DSA</exclude>
                                        <exclude>META-INF/*.RSA</exclude>
                                    </excludes>
                                </filter>
                            </filters>
                        </configuration>
                    </execution>
                </executions>
            </plugin>
        </plugins>
    </build>
</project>
lanking520 commented 2 years ago

There might be JAR or Java class files in the same folder as the model. These interfere with model loading, so try removing them. This issue has been fixed in DJL 0.17.0, which will be released soon. @frankfliu

filipsedivy commented 2 years ago

What do you mean by the same files? Wouldn't it be easier to downgrade?

lanking520 commented 2 years ago

> What do you mean by the same files? Wouldn't it be easier to downgrade?

There might be a .class file or a JAR file in the folder you load the model from, and that is what causes the above issue.

filipsedivy commented 2 years ago

Yes, I understand now, and it really helped. I just had to move the JAR/EXE program into a different folder, so that it was no longer in the same folder as the model, and now it works.

Thanks a lot!