deepjavalibrary / djl

An Engine-Agnostic Deep Learning Framework in Java
https://djl.ai
Apache License 2.0

PyTorch model inference performance is very low #951

Closed zhangyunGit closed 2 years ago

zhangyunGit commented 3 years ago

Description

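I used DJL to load a .pt model converted with TorchScript and found that inference performance is very low, roughly 8 times slower than loading and running the model directly in Python.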

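To find out whether the problem lies in my .pt model or in the DJL framework, I loaded the same .pt model in C++ and compared inference times. Only the forward() time is measured. The code snippet is:

    for (int i = 1; i < 10; i = i + 1) {
        start = clock();  // start timing
        output1 = net.forward({input1, input2, input3, input4});
        end = clock();    // stop timing
        std::cout << "The run time" << i << " is: " << (double)(end - start) / CLOCKS_PER_SEC << "s" << std::endl;
        std::cout << "output1: " << output1 << std::endl;
    }

The results are as follows: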

[Screenshot: torchptcplus]

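As shown in the screenshot above, the 1st and 2nd inferences were slow, at 92 ms and 72 ms respectively; from the 3rd run on, latency dropped to 11-14 ms.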

From the PyTorch GitHub I learned: "We compile the graph for each set of different tensor dimensions that are run and then cache it, so it's likely the first run will be slower." Details: https://github.com/pytorch/pytorch/issues/19106

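Every DJL inference takes about as long as the first C++ inference after loading, around 90 ms. So I suspect that the way DJL loads the libtorch native library through Java JNI (or JNA?) requires reloading it for every inference, which prevents it from making good use of the tensor cache described above, so every call is slow.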

Is there currently any optimization setting I could use to solve this performance problem? Thanks.
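The C++ timings above show the first one or two forward() calls being much slower than later ones, so a fair comparison with DJL should discard warm-up calls on both sides and look at a distribution of steady-state latencies rather than a single printed number. Below is a minimal, framework-independent Java sketch of such a harness. `LatencyHarness`, `Task`, and the parameter names are hypothetical; with DJL the task would be something like `() -> predictor.predict(input)`.

```java
import java.util.Arrays;

/** Hypothetical helper: measures steady-state latency after discarding warm-up calls. */
final class LatencyHarness {

    /** A unit of work that may throw a checked exception, e.g. predictor.predict(input). */
    @FunctionalInterface
    interface Task {
        void run() throws Exception;
    }

    private LatencyHarness() {}

    /**
     * Runs {@code task} {@code warmup} times untimed, then {@code iterations} times timed.
     * Returns the per-call latencies in milliseconds, sorted ascending.
     */
    static double[] measure(Task task, int warmup, int iterations) {
        if (warmup < 0 || iterations <= 0) {
            throw new IllegalArgumentException("need warmup >= 0 and iterations > 0");
        }
        for (int i = 0; i < warmup; i++) {
            runOnce(task); // early calls may pay one-off costs such as per-shape graph compilation
        }
        double[] millis = new double[iterations];
        for (int i = 0; i < iterations; i++) {
            long start = System.nanoTime(); // monotonic clock, intended for elapsed time
            runOnce(task);
            millis[i] = (System.nanoTime() - start) / 1_000_000.0;
        }
        Arrays.sort(millis);
        return millis;
    }

    /** Nearest-rank percentile (0-100) of an ascending-sorted, non-empty array. */
    static double percentile(double[] sorted, double p) {
        int index = (int) Math.ceil(p / 100.0 * sorted.length) - 1;
        return sorted[Math.max(0, Math.min(index, sorted.length - 1))];
    }

    private static void runOnce(Task task) {
        try {
            task.run();
        } catch (Exception e) {
            throw new IllegalStateException("benchmark task failed", e);
        }
    }

    public static void main(String[] args) {
        // Stand-in workload; replace with the real inference call.
        double[] lat = measure(() -> Math.sqrt(42.0), 5, 100);
        System.out.printf("p50 = %.3f ms, p90 = %.3f ms%n", percentile(lat, 50), percentile(lat, 90));
    }
}
```

Reporting a median or p90 after warm-up makes it easier to tell a path that is consistently slow apart from one that is only slow on its first calls.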

Environment Info

Please run the command ./gradlew debugEnv from the root directory of DJL (if necessary, clone DJL first). It will output information about your system, environment, and installation that can help us debug your issue. Paste the output of the command below:

zhangyun@zhangyundeMacBook-Pro djl % ./gradlew debugEnv
/Users/zhangyun/IdeaProjects/djl
/Users/zhangyun/IdeaProjects/djl/gradle/wrapper/gradle-wrapper.jar
java
-Xdock:name=Gradle -Xdock:icon=/Users/zhangyun/IdeaProjects/djl/media/gradle.icns -Dorg.gradle.appname=gradlew -classpath /Users/zhangyun/IdeaProjects/djl/gradle/wrapper/gradle-wrapper.jar org.gradle.wrapper.GradleWrapperMain debugEnv
Starting a Gradle Daemon (subsequent builds will be faster)

> Task :integration:debugEnv
[DEBUG] - Found EngineProvider: MXNet
[DEBUG] - Found EngineProvider: PyTorch
[DEBUG] - Found EngineProvider: TensorFlow
[DEBUG] - Found default engine: MXNet
----------- System Properties -----------
sun.cpu.isalist:
sun.io.unicode.encoding: UnicodeBig
sun.cpu.endian: little
java.vendor.url.bug: http://bugreport.sun.com/bugreport/
file.separator: /
java.vendor: Oracle Corporation
sun.boot.class.path: /Library/Java/JavaVirtualMachines/jdk1.8.0_181.jdk/Contents/Home/jre/lib/resources.jar:/Library/Java/JavaVirtualMachines/jdk1.8.0_181.jdk/Contents/Home/jre/lib/rt.jar:/Library/Java/JavaVirtualMachines/jdk1.8.0_181.jdk/Contents/Home/jre/lib/sunrsasign.jar:/Library/Java/JavaVirtualMachines/jdk1.8.0_181.jdk/Contents/Home/jre/lib/jsse.jar:/Library/Java/JavaVirtualMachines/jdk1.8.0_181.jdk/Contents/Home/jre/lib/jce.jar:/Library/Java/JavaVirtualMachines/jdk1.8.0_181.jdk/Contents/Home/jre/lib/charsets.jar:/Library/Java/JavaVirtualMachines/jdk1.8.0_181.jdk/Contents/Home/jre/lib/jfr.jar:/Library/Java/JavaVirtualMachines/jdk1.8.0_181.jdk/Contents/Home/jre/classes
java.ext.dirs: /Users/zhangyun/Library/Java/Extensions:/Library/Java/JavaVirtualMachines/jdk1.8.0_181.jdk/Contents/Home/jre/lib/ext:/Library/Java/Extensions:/Network/Library/Java/Extensions:/System/Library/Java/Extensions:/usr/lib/java
java.version: 1.8.0_181
java.vm.info: mixed mode
awt.toolkit: sun.lwawt.macosx.LWCToolkit
user.language: zh
java.specification.vendor: Oracle Corporation
sun.java.command: ai.djl.integration.util.DebugEnvironment
java.home: /Library/Java/JavaVirtualMachines/jdk1.8.0_181.jdk/Contents/Home/jre
sun.arch.data.model: 64
java.vm.specification.version: 1.8
java.class.path: /Users/zhangyun/IdeaProjects/djl/integration/build/classes/java/main:/Users/zhangyun/IdeaProjects/djl/integration/build/resources/main:/Users/zhangyun/.gradle/caches/modules-2/files-2.1/commons-cli/commons-cli/1.4/c51c00206bb913cd8612b24abd9fa98ae89719b1/commons-cli-1.4.jar:/Users/zhangyun/.gradle/caches/modules-2/files-2.1/org.apache.logging.log4j/log4j-slf4j-impl/2.13.3/7cca27a921a18645139cf651c04b83b1a19cfd76/log4j-slf4j-impl-2.13.3.jar:/Users/zhangyun/IdeaProjects/djl/basicdataset/build/libs/basicdataset-0.12.0-SNAPSHOT.jar:/Users/zhangyun/IdeaProjects/djl/model-zoo/build/libs/model-zoo-0.12.0-SNAPSHOT.jar:/Users/zhangyun/IdeaProjects/djl/testing/build/libs/testing-0.12.0-SNAPSHOT.jar:/Users/zhangyun/.gradle/caches/modules-2/files-2.1/org.testng/testng/7.1.0/b0bcea778fb2899aeb4014c558babea8833d180a/testng-7.1.0.jar:/Users/zhangyun/IdeaProjects/djl/mxnet/mxnet-model-zoo/build/libs/mxnet-model-zoo-0.12.0-SNAPSHOT.jar:/Users/zhangyun/.gradle/caches/modules-2/files-2.1/ai.djl.mxnet/mxnet-native-auto/1.8.0/e32265c03e27e1fb18c9c0904733b00f9acffaee/mxnet-native-auto-1.8.0.jar:/Users/zhangyun/IdeaProjects/djl/pytorch/pytorch-model-zoo/build/libs/pytorch-model-zoo-0.12.0-SNAPSHOT.jar:/Users/zhangyun/.gradle/caches/modules-2/files-2.1/ai.djl.pytorch/pytorch-native-auto/1.8.1/3cbb59c8b21c24cb368d296f6c4c6ef069d4d9b/pytorch-native-auto-1.8.1.jar:/Users/zhangyun/IdeaProjects/djl/tensorflow/tensorflow-model-zoo/build/libs/tensorflow-model-zoo-0.12.0-SNAPSHOT.jar:/Users/zhangyun/.gradle/caches/modules-2/files-2.1/ai.djl.tensorflow/tensorflow-native-auto/2.4.1/20b8c7a4e6d451e782d15dd30cebd4df0ad86c74/tensorflow-native-auto-2.4.1.jar:/Users/zhangyun/IdeaProjects/djl/mxnet/mxnet-engine/build/libs/mxnet-engine-0.12.0-SNAPSHOT.jar:/Users/zhangyun/IdeaProjects/djl/pytorch/pytorch-engine/build/libs/pytorch-engine-0.12.0-SNAPSHOT.jar:/Users/zhangyun/IdeaProjects/djl/tensorflow/tensorflow-engine/build/libs/tensorflow-engine-0.12.0-SNAPSHOT.jar:/Users/zhangyun/IdeaProjects/djl/api/build/libs/api-0.12.0-SNAPSHOT.jar:/Users/zhangyun/.gradle/caches/modules-2/files-2.1/org.slf4j/slf4j-api/1.7.30/b5a4b6d16ab13e34a88fae84c35cd5d68cac922c/slf4j-api-1.7.30.jar:/Users/zhangyun/.gradle/caches/modules-2/files-2.1/org.apache.logging.log4j/log4j-core/2.13.3/4e857439fc4fe974d212adaaaa3b118b8b50e3ec/log4j-core-2.13.3.jar:/Users/zhangyun/.gradle/caches/modules-2/files-2.1/org.apache.logging.log4j/log4j-api/2.13.3/ec1508160b93d274b1add34419b897bae84c6ca9/log4j-api-2.13.3.jar:/Users/zhangyun/.gradle/caches/modules-2/files-2.1/org.apache.commons/commons-csv/1.8/37ca9a9aa2d4be2599e55506a6d3170dd7a3df4/commons-csv-1.8.jar:/Users/zhangyun/.gradle/caches/modules-2/files-2.1/com.beust/jcommander/1.72/6375e521c1e11d6563d4f25a07ce124ccf8cd171/jcommander-1.72.jar:/Users/zhangyun/.gradle/caches/modules-2/files-2.1/com.google.inject/guice/4.1.0/faf9ee8ac09eafd1128091426dd367a8c0085d55/guice-4.1.0-no_aop.jar:/Users/zhangyun/.gradle/caches/modules-2/files-2.1/org.yaml/snakeyaml/1.21/18775fdda48574784f40b47bf478ab0593f92e4d/snakeyaml-1.21.jar:/Users/zhangyun/.gradle/caches/modules-2/files-2.1/com.google.code.gson/gson/2.8.6/9180733b7df8542621dc12e21e87557e8c99b8cb/gson-2.8.6.jar:/Users/zhangyun/.gradle/caches/modules-2/files-2.1/net.java.dev.jna/jna/5.3.0/4654d1da02e4173ba7b64f7166378847db55448a/jna-5.3.0.jar:/Users/zhangyun/.gradle/caches/modules-2/files-2.1/org.apache.commons/commons-compress/1.20/b8df472b31e1f17c232d2ad78ceb1c84e00c641b/commons-compress-1.20.jar:/Users/zhangyun/.gradle/caches/modules-2/files-2.1/javax.inject/javax.inject/1/6975da39a7040257bd51d21a231b76c915872d38/javax.inject-1.jar:/Users/zhangyun/.gradle/caches/modules-2/files-2.1/aopalliance/aopalliance/1.0/235ba8b489512805ac13a8f9ea77a1ca5ebe3e8/aopalliance-1.0.jar:/Users/zhangyun/.gradle/caches/modules-2/files-2.1/com.google.guava/guava/19.0/6ce200f6b23222af3d8abb6b6459e6c44f4bb0e9/guava-19.0.jar:/Users/zhangyun/IdeaProjects/djl/tensorflow/tensorflow-api/build/libs/tensorflow-api-0.12.0-SNAPSHOT.jar:/Users/zhangyun/.gradle/caches/modules-2/files-2.1/org.bytedeco/javacpp/1.5.5/92e1c31aaed15a3dc12008859a37ced45fa0b730/javacpp-1.5.5.jar:/Users/zhangyun/.gradle/caches/modules-2/files-2.1/org.tensorflow/tensorflow-core-api/0.3.1/954f292e85f4d2a587ede1b2e1a525e74ef96c97/tensorflow-core-api-0.3.1.jar:/Users/zhangyun/.gradle/caches/modules-2/files-2.1/com.google.protobuf/protobuf-java/3.8.0/b5f93103d113540bb848fe9ce4e6819b1f39ee49/protobuf-java-3.8.0.jar:/Users/zhangyun/.gradle/caches/modules-2/files-2.1/org.tensorflow/ndarray/0.3.1/3cdb825411a9de908cc3dac740f18628d6512260/ndarray-0.3.1.jar
user.name: zhangyun
ai.djl.logging.level: debug
file.encoding: UTF-8
java.specification.version: 1.8
java.awt.printerjob: sun.lwawt.macosx.CPrinterJob
user.timezone: Asia/Shanghai
user.home: /Users/zhangyun
library.jansi.path: /Users/zhangyun/.gradle/native/jansi/1.18/osx
os.version: 10.15.7
sun.management.compiler: HotSpot 64-Bit Tiered Compilers
java.specification.name: Java Platform API Specification
java.class.version: 52.0
org.gradle.internal.http.connectionTimeout: 60000
java.library.path: /Users/zhangyun/Library/Java/Extensions:/Library/Java/Extensions:/Network/Library/Java/Extensions:/System/Library/Java/Extensions:/usr/lib/java:.
org.gradle.internal.publish.checksums.insecure: true
sun.jnu.encoding: UTF-8
os.name: Mac OS X
user.variant:
java.vm.specification.vendor: Oracle Corporation
org.gradle.appname: gradlew
java.io.tmpdir: /var/folders/v8/04qt22zx6fxg7sc0clv73ltm0000gn/T/
line.separator:

java.endorsed.dirs: /Library/Java/JavaVirtualMachines/jdk1.8.0_181.jdk/Contents/Home/jre/lib/endorsed
os.arch: x86_64
java.awt.graphicsenv: sun.awt.CGraphicsEnvironment
java.runtime.version: 1.8.0_181-b13
java.vm.specification.name: Java Virtual Machine Specification
user.dir: /Users/zhangyun/IdeaProjects/djl/integration
org.gradle.internal.http.socketTimeout: 120000
user.country: CN
sun.java.launcher: SUN_STANDARD
sun.os.patch.level: unknown
java.vm.name: Java HotSpot(TM) 64-Bit Server VM
file.encoding.pkg: sun.io
path.separator: :
java.vm.vendor: Oracle Corporation
java.vendor.url: http://java.oracle.com/
gopherProxySet: false
sun.boot.library.path: /Library/Java/JavaVirtualMachines/jdk1.8.0_181.jdk/Contents/Home/jre/lib:/Library/Java/JavaVirtualMachines/jdk1.8.0_181.jdk/Contents/Home/jre/lib
java.vm.version: 25.181-b13
java.runtime.name: Java(TM) SE Runtime Environment

--------- Environment Variables ---------
PATH: /usr/local/bin:/usr/local/sbin:/usr/local/bin:/usr/bin:/bin:/usr/sbin:/sbin
JAVA_ARCH: x86_64
LC_TERMINAL: iTerm2
MANPATH: /usr/local/share/man::
TERM: xterm-256color
HOMEBREW_PREFIX: /usr/local
LANG: zh_CN.UTF-8
JAVA_MAIN_CLASS_25112: ai.djl.integration.util.DebugEnvironment
APP_ICON_25100: /Users/zhangyun/IdeaProjects/djl/media/gradle.icns
ITERM_SESSION_ID: w0t0p0:190FF5FF-B349-4339-88D6-F5526A9E896E
COLORTERM: truecolor
LOGNAME: zhangyun
HOMEBREW_REPOSITORY: /usr/local/Homebrew
XPC_SERVICE_NAME: 0
PWD: /Users/zhangyun/IdeaProjects/djl
TERM_PROGRAM_VERSION: 3.4.5
INFOPATH: /usr/local/share/info:
LC_TERMINAL_VERSION: 3.4.5
SHELL: /bin/zsh
TERM_PROGRAM: iTerm.app
HOMEBREW_BOTTLE_DOMAIN: https://mirrors.ustc.edu.cn/homebrew-bottles
HOMEBREW_CELLAR: /usr/local/Cellar
OLDPWD: /Users/zhangyun/IdeaProjects/djl
USER: zhangyun
JAVA_MAIN_CLASS_25100: org.gradle.wrapper.GradleWrapperMain
ITERM_PROFILE: Default
TMPDIR: /var/folders/v8/04qt22zx6fxg7sc0clv73ltm0000gn/T/
SSH_AUTH_SOCK: /private/tmp/com.apple.launchd.y0G8jwr6p8/Listeners
XPC_FLAGS: 0x0
TERM_SESSION_ID: w0t0p0:190FF5FF-B349-4339-88D6-F5526A9E896E
__CF_USER_TEXT_ENCODING: 0x1F5:0x19:0x34
com.apple.java.jvmTask: CommandLine
APP_NAME_25100: Gradle
COLORFGBG: 7;0
HOME: /Users/zhangyun
SHLVL: 1

-------------- Directories --------------
temp directory: /var/folders/v8/04qt22zx6fxg7sc0clv73ltm0000gn/T
DJL cache directory: /Users/zhangyun/.djl.ai
Engine cache directory: /Users/zhangyun/.djl.ai

------------------ CUDA -----------------
[DEBUG] - cudart library not found.
[DEBUG] - Using cache dir: /Users/zhangyun/.djl.ai/mxnet
[DEBUG] - Loading mxnet library from: /Users/zhangyun/.djl.ai/mxnet/1.8.0-mkl-osx-x86_64/libmxnet.dylib
GPU Count: 0
Default Device: cpu()

----------------- Engines ---------------
Default Engine: MXNet
PyTorch: 2
[DEBUG] - Using cache dir: /Users/zhangyun/.djl.ai/pytorch
[INFO ] - Downloading https://publish.djl.ai/pytorch-1.8.1/cpu/osx/native/lib/libtorch.dylib.gz ...
[INFO ] - Downloading https://publish.djl.ai/pytorch-1.8.1/cpu/osx/native/lib/libiomp5.dylib.gz ...
[INFO ] - Downloading https://publish.djl.ai/pytorch-1.8.1/cpu/osx/native/lib/libc10.dylib.gz ...
[INFO ] - Downloading https://publish.djl.ai/pytorch-1.8.1/cpu/osx/native/lib/libtorch_cpu.dylib.gz ...
[DEBUG] - Loading pytorch library from: /Users/zhangyun/.djl.ai/pytorch/1.8.1-cpu-osx-x86_64/0.12.0-SNAPSHOT-cpu-libdjl_torch.dylib
[INFO ] - Number of inter-op threads is 4
[INFO ] - Number of intra-op threads is 4
MXNet: 0
TensorFlow: 3
[DEBUG] - Using cache dir: /Users/zhangyun/.djl.ai/tensorflow
[INFO ] - Downloading https://publish.djl.ai/tensorflow-2.4.1/osx/cpu/THIRD_PARTY_TF_JNI_LICENSES.gz ...
[INFO ] - Downloading https://publish.djl.ai/tensorflow-2.4.1/osx/cpu/libtensorflow_cc.2.dylib.gz ...
[INFO ] - Downloading https://publish.djl.ai/tensorflow-2.4.1/osx/cpu/LICENSE.gz ...
[INFO ] - Downloading https://publish.djl.ai/tensorflow-2.4.1/osx/cpu/libtensorflow_framework.2.dylib.gz ...
[INFO ] - Downloading https://publish.djl.ai/tensorflow-2.4.1/osx/cpu/libjnitensorflow.dylib.gz ...
[DEBUG] - Loading TensorFlow library from: /Users/zhangyun/.djl.ai/tensorflow/2.4.1-cpu-osx-x86_64/libjnitensorflow.dylib
Warning: Could not load Loader: java.lang.UnsatisfiedLinkError: no jnijavacpp in java.library.path
2021-05-11 11:13:09.281680: I external/org_tensorflow/tensorflow/core/platform/cpu_feature_guard.cc:142] This TensorFlow binary is optimized with oneAPI Deep Neural Network Library (oneDNN) to use the following CPU instructions in performance-critical operations:  AVX2 FMA
To enable them in other operations, rebuild TensorFlow with the appropriate compiler flags.

--------------- Hardware --------------
Available processors (cores): 8
Byte Order: LITTLE_ENDIAN
Free memory (bytes): 458468648
Maximum memory (bytes): 7635730432
Total memory available to JVM (bytes): 495976448
Heap committed: 495976448
Heap nonCommitted: 31457280
GCC:
Apple clang version 12.0.0 (clang-1200.0.32.29)
Target: x86_64-apple-darwin19.6.0
Thread model: posix
InstalledDir: /Library/Developer/CommandLineTools/usr/bin

Deprecated Gradle features were used in this build, making it incompatible with Gradle 7.0.
Use '--warning-mode all' to show the individual deprecation warnings.
See https://docs.gradle.org/6.7.1/userguide/command_line_interface.html#sec:command_line_warnings

BUILD SUCCESSFUL in 41s
39 actionable tasks: 1 executed, 38 up-to-date
frankfliu commented 3 years ago

Our benchmarks show that DJL PyTorch performs similarly to running the model in Python. See http://docs.djl.ai/master/docs/development/benchmark_with_djl.html for how we test performance.

Can you share the code you use to measure performance with DJL?

zhangyunGit commented 3 years ago

@frankfliu My model is an ESIM text-matching model from the NLP domain, served by a web application built with Spring Boot and DJL. I timed DJL's batchPredict on its own, and across many HTTP requests it consistently takes about 90 ms. The timing code snippet is:

    try {
        long start = new Date().getTime();
        List<Classifications> result = predictor.batchPredict(inputs);
        long end = new Date().getTime();
        System.out.println((end - start));
        return result;
    } catch (TranslateException e) {
        e.printStackTrace();
    }

The complete code is:

package com.dfx.test.djlserv;

import ai.djl.Application;
import ai.djl.MalformedModelException;
import ai.djl.Model;
import ai.djl.inference.Predictor;
import ai.djl.modality.Classifications;
import ai.djl.modality.nlp.SimpleVocabulary;
import ai.djl.ndarray.NDArray;
import ai.djl.ndarray.NDList;
import ai.djl.ndarray.NDManager;
import ai.djl.repository.Artifact;
import ai.djl.repository.zoo.Criteria;
import ai.djl.repository.zoo.ModelNotFoundException;
import ai.djl.repository.zoo.ModelZoo;
import ai.djl.repository.zoo.ZooModel;
import ai.djl.training.util.ProgressBar;
import ai.djl.translate.Batchifier;
import ai.djl.translate.TranslateException;
import ai.djl.translate.Translator;
import ai.djl.translate.TranslatorContext;
import com.google.common.primitives.Floats;
import org.springframework.stereotype.Component;
import org.springframework.util.ResourceUtils;

import java.io.FileNotFoundException;
import java.io.IOException;
import java.net.MalformedURLException;
import java.nio.file.Path;
import java.nio.file.Paths;
import java.util.*;
import java.util.stream.Collectors;

@Component
public class DJLEsimTest {

private String vocabularyPath;
private String modelPath;
private ZooModel<QueryDocInput, Classifications> model;
Predictor<QueryDocInput,Classifications> predictor;

public DJLEsimTest(){
    init();
}

private void init(){
    try {
        this.vocabularyPath = ResourceUtils.getFile("classpath:vocab_large.txt").getPath();
        this.modelPath = ResourceUtils.getFile("classpath:models").getPath();
        this.model = getModel();
        this.predictor = this.getPredictor();
    } catch (FileNotFoundException e) {
        e.printStackTrace();
    }
}

public class EsimTranslator implements Translator<QueryDocInput,Classifications>{

    private Vocab vocab;

    @Override
    public void prepare(NDManager manager, Model model) throws IOException {
        vocab = new Vocab(vocabularyPath);
    }

    @Override
    public Batchifier getBatchifier() {
        return Batchifier.STACK;
    }

    @Override
    public NDList processInput(TranslatorContext translatorContext, QueryDocInput queryDocInput) throws Exception {

        String query = queryDocInput.getQuery();
        String doc = queryDocInput.getDoc();

        long[] queryIds = Arrays.stream(this.vocab.text2Ids(query)).mapToLong(Integer::longValue).toArray();
        long[] docIds = Arrays.stream(this.vocab.text2Ids(doc)).mapToLong(Integer::longValue).toArray();

        // seqLen is a static field of Vocab; access it through the class on both lines
        long queryLen = Vocab.seqLen.longValue();
        long docLen = Vocab.seqLen.longValue();

        NDManager manager = translatorContext.getNDManager();
        NDArray  queryArray = manager.create(queryIds);
        NDArray docArray = manager.create(docIds);
        NDArray queryLenArray = manager.create(queryLen);
        NDArray docLenArray = manager.create(docLen);

        return new NDList(queryArray,queryLenArray,docArray,docLenArray);
    }

    @Override
    public Classifications processOutput(TranslatorContext translatorContext, NDList ndList) throws Exception {

        NDArray probArray = ndList.get(1);

        List<String> classNames = Arrays.asList("0","1");
        List<Double> probabilities = Floats.asList(probArray.toFloatArray()).stream().mapToDouble(Float::doubleValue).boxed().collect(Collectors.toList());

        Classifications classifications = new Classifications(classNames,probabilities);
        return classifications;
    }

}

private ZooModel<QueryDocInput, Classifications> getModel() {
    EsimTranslator translator = new EsimTranslator();
    //System.setProperty("ai.djl.pytorch:pytorch-model-zoo", "build/models/trace_esim_model");

    try {
        Criteria<QueryDocInput, Classifications> criteria = Criteria.builder()
                .setTypes(QueryDocInput.class, Classifications.class)
                //.optModelPath(Paths.get("build/models/trace_esim_model/")) // search in local folder
                .optModelUrls(modelPath)
                .optModelName("trace_esim_model")
                .optTranslator(translator)
                .optProgress(new ProgressBar()).build();
        ZooModel<QueryDocInput, Classifications> model = ModelZoo.loadModel(criteria);
        return model;
    } catch (MalformedURLException e) {
        e.printStackTrace();
    } catch (MalformedModelException e) {
        e.printStackTrace();
    } catch (ModelNotFoundException e) {
        e.printStackTrace();
    } catch (IOException e) {
        e.printStackTrace();
    }
    return null;
}

public Classifications predict(String query,String doc){
    QueryDocInput input = new QueryDocInput(query,doc);

    try {
        return predictor.predict(input);
    } catch (TranslateException e) {
        e.printStackTrace();
    }
    return null;
}

public List<Classifications> batchPredict(List<QueryDocInput> inputs){
   // padInputs(inputs);
    try {
        long start = new Date().getTime();
        List<Classifications> result = predictor.batchPredict(inputs);
        long end = new Date().getTime();
        System.out.println((end-start));
        return result;
    } catch (TranslateException e) {
        e.printStackTrace();
    }
    return null;
}

private void padInputs(List<QueryDocInput> inputs){
    if(inputs.size()<50){
        String padUnk = "[UNK][UNK][UNK][UNK][UNK][UNK][UNK][UNK][UNK][UNK]";
        QueryDocInput padInput = new QueryDocInput(padUnk,padUnk);
        while (inputs.size()<50){
            inputs.add(padInput);
        }
    }
}

private Predictor<QueryDocInput,Classifications>  getPredictor(){
    EsimTranslator translator = new EsimTranslator();
    try {
        Predictor<QueryDocInput,Classifications> predictor = this.model.newPredictor(translator);
        return predictor;
    } catch (Exception e) {
        e.printStackTrace();
    }
    return null;
}

}
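The commented-out `padInputs` above pads every batch to 50 items, which keeps the input tensor shapes the same from request to request. Per the PyTorch note quoted earlier in the issue, TorchScript compiles and caches a graph per set of input dimensions, so varying batch sizes could trigger that first-run cost again and again. Note that `padInputs` mutates the caller's list and leaves the filler results in the returned output. A generic sketch that copies instead and trims the filler outputs, assuming a batch never exceeds the fixed size (`FixedBatch` and its methods are hypothetical names):

```java
import java.util.ArrayList;
import java.util.Collections;
import java.util.List;
import java.util.function.Function;

/** Hypothetical helpers: pad a batch to a fixed size, run it, then drop the filler outputs. */
final class FixedBatch {

    private FixedBatch() {}

    /** Returns a new list of exactly {@code batchSize} items: the inputs followed by fillers. */
    static <T> List<T> pad(List<T> inputs, int batchSize, T filler) {
        if (inputs.size() > batchSize) {
            throw new IllegalArgumentException(
                    "batch of " + inputs.size() + " exceeds fixed size " + batchSize);
        }
        List<T> padded = new ArrayList<>(batchSize);
        padded.addAll(inputs); // copy, so the caller's list is left unchanged
        padded.addAll(Collections.nCopies(batchSize - inputs.size(), filler));
        return padded;
    }

    /** Pads, runs {@code batchFn} on the padded batch, and keeps only the outputs of real inputs. */
    static <I, O> List<O> predictPadded(
            List<I> inputs, int batchSize, I filler, Function<List<I>, List<O>> batchFn) {
        List<O> outputs = batchFn.apply(pad(inputs, batchSize, filler));
        // Outputs are assumed to come back in input order, as with a stacked batch.
        return new ArrayList<>(outputs.subList(0, inputs.size()));
    }
}
```

With DJL, `batchFn` would wrap `predictor.batchPredict`, which throws the checked `TranslateException`, so the lambda would need to catch or rethrow it. Batches larger than the fixed size would have to be split into chunks, which this sketch does not do.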

zhangyunGit commented 3 years ago

I ran the benchmark locally 3 times. The first run was relatively slow; the next 2 were faster.

zhangyun@zhangyundeMacBook-Pro djl % ./gradlew benchmark -Dai.djl.default_engine=PyTorch --args="-c 1 -s 1,3,300,300 -n ai.djl.pytorch:ssd -r {'size':'300','backbone':'resnet50'}"

/Users/zhangyun/IdeaProjects/djl
/Users/zhangyun/IdeaProjects/djl/gradle/wrapper/gradle-wrapper.jar
java
-Xdock:name=Gradle -Xdock:icon=/Users/zhangyun/IdeaProjects/djl/media/gradle.icns -Dorg.gradle.appname=gradlew -classpath /Users/zhangyun/IdeaProjects/djl/gradle/wrapper/gradle-wrapper.jar org.gradle.wrapper.GradleWrapperMain benchmark -Dai.djl.default_engine=PyTorch --args=-c 1 -s 1,3,300,300 -n ai.djl.pytorch:ssd -r {'size':'300','backbone':'resnet50'}

> Task :examples:benchmark
[INFO ] - Number of inter-op threads is 4
[INFO ] - Number of intra-op threads is 4
[INFO ] - Load library 1.8.1 in 420.746 ms.
[INFO ] - Running Benchmark on: cpu().
Downloading: 100% |████████████████████████████████████████|
Loading:     100% |████████████████████████████████████████|
[INFO ] - Model ssd_300_resnet50 loaded in: 6480.265 ms.
[INFO ] - Inference result: [0.51217073, -0.12814945, 0.06952064 ...]
[INFO ] - Throughput: 0.14, completed 1 iteration in 7225 ms.
[INFO ] - Model loading time: 6480.265 ms.

Deprecated Gradle features were used in this build, making it incompatible with Gradle 7.0.
Use '--warning-mode all' to show the individual deprecation warnings.
See https://docs.gradle.org/6.7.1/userguide/command_line_interface.html#sec:command_line_warnings

BUILD SUCCESSFUL in 9s
18 actionable tasks: 1 executed, 17 up-to-date
zhangyun@zhangyundeMacBook-Pro djl % ./gradlew benchmark -Dai.djl.default_engine=PyTorch --args="-c 1 -s 1,3,300,300 -n ai.djl.pytorch:ssd -r {'size':'300','backbone':'resnet50'}"

/Users/zhangyun/IdeaProjects/djl
/Users/zhangyun/IdeaProjects/djl/gradle/wrapper/gradle-wrapper.jar
java
-Xdock:name=Gradle -Xdock:icon=/Users/zhangyun/IdeaProjects/djl/media/gradle.icns -Dorg.gradle.appname=gradlew -classpath /Users/zhangyun/IdeaProjects/djl/gradle/wrapper/gradle-wrapper.jar org.gradle.wrapper.GradleWrapperMain benchmark -Dai.djl.default_engine=PyTorch --args=-c 1 -s 1,3,300,300 -n ai.djl.pytorch:ssd -r {'size':'300','backbone':'resnet50'}

> Task :examples:benchmark
[INFO ] - Number of inter-op threads is 4
[INFO ] - Number of intra-op threads is 4
[INFO ] - Load library 1.8.1 in 421.552 ms.
[INFO ] - Running Benchmark on: cpu().
Loading:     100% |████████████████████████████████████████|
[INFO ] - Model ssd_300_resnet50 loaded in: 250.019 ms.
[INFO ] - Inference result: [0.51217073, -0.12814945, 0.06952064 ...]
[INFO ] - Throughput: 1.01, completed 1 iteration in 995 ms.
[INFO ] - Model loading time: 250.019 ms.

Deprecated Gradle features were used in this build, making it incompatible with Gradle 7.0.
Use '--warning-mode all' to show the individual deprecation warnings.
See https://docs.gradle.org/6.7.1/userguide/command_line_interface.html#sec:command_line_warnings

BUILD SUCCESSFUL in 3s
18 actionable tasks: 1 executed, 17 up-to-date
zhangyun@zhangyundeMacBook-Pro djl % ./gradlew benchmark -Dai.djl.default_engine=PyTorch --args="-c 1 -s 1,3,300,300 -n ai.djl.pytorch:ssd -r {'size':'300','backbone':'resnet50'}"

/Users/zhangyun/IdeaProjects/djl
/Users/zhangyun/IdeaProjects/djl/gradle/wrapper/gradle-wrapper.jar
java
-Xdock:name=Gradle -Xdock:icon=/Users/zhangyun/IdeaProjects/djl/media/gradle.icns -Dorg.gradle.appname=gradlew -classpath /Users/zhangyun/IdeaProjects/djl/gradle/wrapper/gradle-wrapper.jar org.gradle.wrapper.GradleWrapperMain benchmark -Dai.djl.default_engine=PyTorch --args=-c 1 -s 1,3,300,300 -n ai.djl.pytorch:ssd -r {'size':'300','backbone':'resnet50'}

> Task :examples:benchmark
[INFO ] - Number of inter-op threads is 4
[INFO ] - Number of intra-op threads is 4
[INFO ] - Load library 1.8.1 in 263.288 ms.
[INFO ] - Running Benchmark on: cpu().
Loading:     100% |████████████████████████████████████████|
[INFO ] - Model ssd_300_resnet50 loaded in: 239.253 ms.
[INFO ] - Inference result: [0.51217073, -0.12814945, 0.06952064 ...]
[INFO ] - Throughput: 1.03, completed 1 iteration in 975 ms.
[INFO ] - Model loading time: 239.253 ms.

Deprecated Gradle features were used in this build, making it incompatible with Gradle 7.0.
Use '--warning-mode all' to show the individual deprecation warnings.
See https://docs.gradle.org/6.7.1/userguide/command_line_interface.html#sec:command_line_warnings

BUILD SUCCESSFUL in 3s
18 actionable tasks: 1 executed, 17 up-to-date
zhangyun@zhangyundeMacBook-Pro djl %
frankfliu commented 3 years ago

You can test your own model as well. A few things to be aware of:

  1. Increase the number of iterations, e.g. "-c 1000"; a benchmark with "-c 1" is not meaningful.
  2. There is a known performance issue in PyTorch 1.8.1: we noticed 1.8.1 is about 30% slower than 1.7.1. You can check out DJL 0.10.0 to test PyTorch 1.7.1 performance.
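The `Throughput` figures in the three benchmark runs above match iterations divided by elapsed seconds (1 / 7.225 s ≈ 0.14, 1 / 0.995 s ≈ 1.01, 1 / 0.975 s ≈ 1.03). With `-c 1` the only measured iteration is also the first one, so any one-off cost lands entirely in the result. The sketch below shows that arithmetic and how a slow start is amortized over more iterations. The 92 ms and 72 ms first calls come from the C++ run in the issue; 12 ms is an illustrative steady-state value picked from the reported 11-14 ms range. `BenchmarkMath` is a hypothetical name.

```java
import java.util.Locale;

/** Hypothetical helpers for reading benchmark numbers. */
final class BenchmarkMath {

    private BenchmarkMath() {}

    /** Throughput in iterations per second. */
    static double throughput(int iterations, double elapsedMillis) {
        return iterations / (elapsedMillis / 1000.0);
    }

    /**
     * Mean latency over {@code n} calls when the first calls take the given one-off
     * latencies and every later call takes {@code steadyMillis}.
     */
    static double meanLatency(double[] firstCallsMillis, double steadyMillis, int n) {
        int slow = Math.min(firstCallsMillis.length, n);
        double total = 0;
        for (int i = 0; i < slow; i++) {
            total += firstCallsMillis[i];
        }
        total += (n - slow) * steadyMillis; // remaining calls run at the steady-state latency
        return total / n;
    }

    public static void main(String[] args) {
        System.out.println(String.format(Locale.ROOT, "%.2f", throughput(1, 7225))); // 0.14
        System.out.println(String.format(Locale.ROOT, "%.2f", throughput(1, 975)));  // 1.03
        double[] firstCalls = {92, 72}; // C++ timings of the first two forward() calls
        System.out.println(String.format(Locale.ROOT, "%.2f", meanLatency(firstCalls, 12, 1)));    // 92.00
        System.out.println(String.format(Locale.ROOT, "%.2f", meanLatency(firstCalls, 12, 1000))); // 12.14
    }
}
```

A single iteration reports the full 92 ms first call, while over 1000 iterations the mean lands close to the steady-state value, which is why a larger `-c` gives a more representative number.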
lanking520 commented 3 years ago

Apart from that,

C++: it seems you keep using the same input for the PyTorch model. Can you also include tensor creation in the timed section? Also, please try random inputs (e.g. created with torch::rand, or filled with Tensor::uniform_()) to avoid caching effects.

Java: benchmark with plain NDArray inputs only, e.g. create NDArrays of a fixed shape with NDManager.randomUniform(). This gives an apples-to-apples comparison.
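Both suggestions amount to timing the same end-to-end work on each side: build a fresh random input on every iteration, so nothing can be served from a cache, and keep input creation inside the timed region. A framework-neutral Java sketch follows, where `randomUniform` is a plain stand-in for creating a random NDArray (for example with `NDManager.randomUniform()` in DJL) and `forward` is a placeholder for the model call; `EndToEndTimer` and its methods are hypothetical.

```java
import java.util.Random;
import java.util.function.Function;
import java.util.function.Supplier;

/** Hypothetical end-to-end timer: input creation plus forward pass, with a fresh input each time. */
final class EndToEndTimer {

    // Writing each result to a volatile field keeps the JIT from discarding the call.
    private static volatile Object sink;

    private EndToEndTimer() {}

    /** Mean end-to-end latency in milliseconds over {@code iterations} runs. */
    static <I, O> double meanMillis(Supplier<I> makeInput, Function<I, O> forward, int iterations) {
        if (iterations <= 0) {
            throw new IllegalArgumentException("iterations must be > 0");
        }
        long totalNanos = 0;
        for (int i = 0; i < iterations; i++) {
            long start = System.nanoTime();
            I input = makeInput.get();   // input creation is inside the timed region
            sink = forward.apply(input); // placeholder for the model's forward pass
            totalNanos += System.nanoTime() - start;
        }
        return totalNanos / 1_000_000.0 / iterations;
    }

    /** Stand-in for a random tensor: {@code size} uniform floats in [0, 1). */
    static float[] randomUniform(Random rng, int size) {
        float[] data = new float[size];
        for (int i = 0; i < size; i++) {
            data[i] = rng.nextFloat();
        }
        return data;
    }

    public static void main(String[] args) {
        Random rng = new Random(0);
        int size = 1 * 3 * 300 * 300; // same element count as the -s 1,3,300,300 benchmark input
        double ms = meanMillis(() -> randomUniform(rng, size), x -> x[0] + x[x.length - 1], 20);
        System.out.println("mean end-to-end latency: " + ms + " ms");
    }
}
```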

Note: SSD is a model we benchmark every day, and over the past weeks we have seen no difference between Python and Java. You can use our benchmark script to verify: http://docs.djl.ai/master/docs/development/benchmark_with_djl.html

(Replying to Frank Liu's comment of May 10, 2021, 8:42 PM.)

hongyaohongyao commented 3 years ago

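ONNX is very fast, but there's a memory leak 😂. I tried again and found it isn't actually a memory leak: inference is so fast that memory isn't released in time, which leads to OOM.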
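One plausible reading of this report: tensors used by engines such as ONNX Runtime or PyTorch live in native (off-heap) memory that the JVM garbage collector does not track, so if that memory is only freed once the owning Java objects are eventually collected, fast inference can allocate faster than memory is returned. The usual remedy is to tie native allocations to a scope and release them deterministically with try-with-resources; in DJL, `NDManager` is `AutoCloseable` and can serve as such a scope. The sketch below uses a hypothetical `NativeScope` class that only counts bytes, so it runs without DJL; it illustrates the pattern, not DJL's actual implementation.

```java
import java.util.ArrayList;
import java.util.List;

/**
 * Hypothetical stand-in for a native-memory scope such as a per-request NDManager.
 * It only counts bytes and is a single-threaded sketch.
 */
final class NativeScope implements AutoCloseable {

    private static long liveBytes; // bytes currently held by all open scopes
    private final List<Integer> allocations = new ArrayList<>();
    private boolean closed;

    /** Records an allocation owned by this scope; a real engine would allocate native memory here. */
    void allocate(int bytes) {
        if (closed) {
            throw new IllegalStateException("scope already closed");
        }
        allocations.add(bytes);
        liveBytes += bytes;
    }

    /** Releases everything this scope owns right away instead of waiting for garbage collection. */
    @Override
    public void close() {
        if (closed) {
            return;
        }
        for (int bytes : allocations) {
            liveBytes -= bytes;
        }
        allocations.clear();
        closed = true;
    }

    static long liveBytes() {
        return liveBytes;
    }

    public static void main(String[] args) {
        for (int request = 0; request < 1000; request++) {
            try (NativeScope scope = new NativeScope()) {
                scope.allocate(4 * 3 * 300 * 300); // e.g. a 1x3x300x300 float32 input
                // ... run inference using memory owned by this scope ...
            } // everything allocated above is released here, on every iteration
        }
        System.out.println("live bytes after 1000 requests: " + liveBytes()); // 0
    }
}
```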

lanking520 commented 2 years ago

Closing this issue due to inactivity. Please feel free to reopen it if you still have the issue.

visionwxc commented 1 year ago

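> ONNX is very fast, but there's a memory leak 😂. I tried again and found it isn't actually a memory leak: inference is so fast that memory isn't released in time, which leads to OOM.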

Is there any way to solve this?