Closed: zhangyunGit closed this issue 2 years ago.
Our benchmark shows DJL PyTorch has performance similar to running the model in Python. See http://docs.djl.ai/master/docs/development/benchmark_with_djl.html for how we test performance.
Can you share the code you used to test performance with DJL?
@frankfliu
My model is an ESIM text-matching model (NLP). The web service is built with Spring Boot plus DJL. Timing DJL's batchPredict alone, repeated HTTP requests all come in at around 90 ms.
The timing code snippet is as follows:
try {
    long start = new Date().getTime();
    List<Classifications> result = predictor.batchPredict(inputs);
    long end = new Date().getTime();
    System.out.println(end - start);
    return result;
} catch (TranslateException e) {
    e.printStackTrace();
}
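As an aside, a steadier way to time this is to average System.nanoTime() over several calls, since Date has only millisecond resolution. This is a minimal sketch assuming the same predictor and inputs as above; the run count is arbitrary:

try {
    int runs = 20; // arbitrary measurement count
    long totalNanos = 0;
    for (int i = 0; i < runs; i++) {
        long start = System.nanoTime();
        predictor.batchPredict(inputs);
        totalNanos += System.nanoTime() - start;
    }
    System.out.printf("avg batchPredict: %.2f ms%n", totalNanos / (double) runs / 1e6);
} catch (TranslateException e) {
    e.printStackTrace();
}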
The full code is as follows:
package com.dfx.test.djlserv;
import ai.djl.Application;
import ai.djl.MalformedModelException;
import ai.djl.Model;
import ai.djl.inference.Predictor;
import ai.djl.modality.Classifications;
import ai.djl.modality.nlp.SimpleVocabulary;
import ai.djl.ndarray.NDArray;
import ai.djl.ndarray.NDList;
import ai.djl.ndarray.NDManager;
import ai.djl.repository.Artifact;
import ai.djl.repository.zoo.Criteria;
import ai.djl.repository.zoo.ModelNotFoundException;
import ai.djl.repository.zoo.ModelZoo;
import ai.djl.repository.zoo.ZooModel;
import ai.djl.training.util.ProgressBar;
import ai.djl.translate.Batchifier;
import ai.djl.translate.TranslateException;
import ai.djl.translate.Translator;
import ai.djl.translate.TranslatorContext;
import com.google.common.primitives.Floats;
import org.springframework.stereotype.Component;
import org.springframework.util.ResourceUtils;
import java.io.FileNotFoundException;
import java.io.IOException;
import java.net.MalformedURLException;
import java.nio.file.Path;
import java.nio.file.Paths;
import java.util.*;
import java.util.stream.Collectors;
@Component
public class DJLEsimTest {
private String vocabularyPath;
private String modelPath;
private ZooModel<QueryDocInput, Classifications> model;
Predictor<QueryDocInput,Classifications> predictor;
public DJLEsimTest(){
init();
}
private void init(){
try {
this.vocabularyPath = ResourceUtils.getFile("classpath:vocab_large.txt").getPath();
this.modelPath = ResourceUtils.getFile("classpath:models").getPath();
this.model = getModel();
this.predictor = this.getPredictor();
} catch (FileNotFoundException e) {
e.printStackTrace();
}
}
public class EsimTranslator implements Translator<QueryDocInput,Classifications>{
private Vocab vocab;
@Override
public void prepare(NDManager manager, Model model) throws IOException {
vocab = new Vocab(vocabularyPath);
}
@Override
public Batchifier getBatchifier() {
return Batchifier.STACK;
}
@Override
public NDList processInput(TranslatorContext translatorContext, QueryDocInput queryDocInput) throws Exception {
String query = queryDocInput.getQuery();
String doc = queryDocInput.getDoc();
long[] queryIds = Arrays.stream(this.vocab.text2Ids(query)).mapToLong(Integer::longValue).toArray();
long[] docIds = Arrays.stream(this.vocab.text2Ids(doc)).mapToLong(Integer::longValue).toArray();
long queryLen = Vocab.seqLen.longValue();
long docLen = vocab.seqLen.longValue();
NDManager manager = translatorContext.getNDManager();
NDArray queryArray = manager.create(queryIds);
NDArray docArray = manager.create(docIds);
NDArray queryLenArray = manager.create(queryLen);
NDArray docLenArray = manager.create(docLen);
return new NDList(queryArray,queryLenArray,docArray,docLenArray);
}
@Override
public Classifications processOutput(TranslatorContext translatorContext, NDList ndList) throws Exception {
NDArray probArray = ndList.get(1);
List<String> classNames = Arrays.asList("0","1");
List<Double> probabilities = Floats.asList(probArray.toFloatArray()).stream().mapToDouble(Float::doubleValue).boxed().collect(Collectors.toList());
Classifications classifications = new Classifications(classNames,probabilities);
return classifications;
}
}
private ZooModel<QueryDocInput, Classifications> getModel(){
EsimTranslator translator = new EsimTranslator();
//System.setProperty("ai.djl.pytorch:pytorch-model-zoo", "build/models/trace_esim_model");
try {
Criteria<QueryDocInput, Classifications> criteria = Criteria.builder()
.setTypes(QueryDocInput.class, Classifications.class)
//.optModelPath(Paths.get("build/models/trace_esim_model/")) // search in local folder
.optModelUrls(modelPath)
.optModelName("trace_esim_model")
.optTranslator(translator)
.optProgress(new ProgressBar()).build();
ZooModel<QueryDocInput, Classifications> model = ModelZoo.loadModel(criteria);
return model;
} catch (MalformedURLException e) {
e.printStackTrace();
} catch (MalformedModelException e) {
e.printStackTrace();
} catch (ModelNotFoundException e) {
e.printStackTrace();
} catch (IOException e) {
e.printStackTrace();
}
return null;
}
public Classifications predict(String query,String doc){
QueryDocInput input = new QueryDocInput(query,doc);
try {
return predictor.predict(input);
} catch (TranslateException e) {
e.printStackTrace();
}
return null;
}
public List<Classifications> batchPredict(List<QueryDocInput> inputs){
// padInputs(inputs);
try {
long start = new Date().getTime();
List<Classifications> result = predictor.batchPredict(inputs);
long end = new Date().getTime();
System.out.println((end-start));
return result;
} catch (TranslateException e) {
e.printStackTrace();
}
return null;
}
private void padInputs(List<QueryDocInput> inputs){
if(inputs.size()<50){
String padUnk = "[UNK][UNK][UNK][UNK][UNK][UNK][UNK][UNK][UNK][UNK]";
QueryDocInput padInput = new QueryDocInput(padUnk,padUnk);
while (inputs.size()<50){
inputs.add(padInput);
}
}
}
private Predictor<QueryDocInput,Classifications> getPredictor(){
EsimTranslator translator = new EsimTranslator();
try {
Predictor<QueryDocInput,Classifications> predictor = this.model.newPredictor(translator);
return predictor;
} catch (Exception e) {
e.printStackTrace();
}
return null;
}
}
I ran the benchmark locally three times; the first run was slow, and the later two were much faster.
zhangyun@zhangyundeMacBook-Pro djl % ./gradlew benchmark -Dai.djl.default_engine=PyTorch --args="-c 1 -s 1,3,300,300 -n ai.djl.pytorch:ssd -r {'size':'300','backbone':'resnet50'}"
/Users/zhangyun/IdeaProjects/djl
/Users/zhangyun/IdeaProjects/djl/gradle/wrapper/gradle-wrapper.jar
java
-Xdock:name=Gradle -Xdock:icon=/Users/zhangyun/IdeaProjects/djl/media/gradle.icns -Dorg.gradle.appname=gradlew -classpath /Users/zhangyun/IdeaProjects/djl/gradle/wrapper/gradle-wrapper.jar org.gradle.wrapper.GradleWrapperMain benchmark -Dai.djl.default_engine=PyTorch --args=-c 1 -s 1,3,300,300 -n ai.djl.pytorch:ssd -r {'size':'300','backbone':'resnet50'}
> Task :examples:benchmark
[INFO ] - Number of inter-op threads is 4
[INFO ] - Number of intra-op threads is 4
[INFO ] - Load library 1.8.1 in 420.746 ms.
[INFO ] - Running Benchmark on: cpu().
Downloading: 100% |████████████████████████████████████████|
Loading: 100% |████████████████████████████████████████|
[INFO ] - Model ssd_300_resnet50 loaded in: 6480.265 ms.
[INFO ] - Inference result: [0.51217073, -0.12814945, 0.06952064 ...]
[INFO ] - Throughput: 0.14, completed 1 iteration in 7225 ms.
[INFO ] - Model loading time: 6480.265 ms.
Deprecated Gradle features were used in this build, making it incompatible with Gradle 7.0.
Use '--warning-mode all' to show the individual deprecation warnings.
See https://docs.gradle.org/6.7.1/userguide/command_line_interface.html#sec:command_line_warnings
BUILD SUCCESSFUL in 9s
18 actionable tasks: 1 executed, 17 up-to-date
zhangyun@zhangyundeMacBook-Pro djl % ./gradlew benchmark -Dai.djl.default_engine=PyTorch --args="-c 1 -s 1,3,300,300 -n ai.djl.pytorch:ssd -r {'size':'300','backbone':'resnet50'}"
/Users/zhangyun/IdeaProjects/djl
/Users/zhangyun/IdeaProjects/djl/gradle/wrapper/gradle-wrapper.jar
java
-Xdock:name=Gradle -Xdock:icon=/Users/zhangyun/IdeaProjects/djl/media/gradle.icns -Dorg.gradle.appname=gradlew -classpath /Users/zhangyun/IdeaProjects/djl/gradle/wrapper/gradle-wrapper.jar org.gradle.wrapper.GradleWrapperMain benchmark -Dai.djl.default_engine=PyTorch --args=-c 1 -s 1,3,300,300 -n ai.djl.pytorch:ssd -r {'size':'300','backbone':'resnet50'}
> Task :examples:benchmark
[INFO ] - Number of inter-op threads is 4
[INFO ] - Number of intra-op threads is 4
[INFO ] - Load library 1.8.1 in 421.552 ms.
[INFO ] - Running Benchmark on: cpu().
Loading: 100% |████████████████████████████████████████|
[INFO ] - Model ssd_300_resnet50 loaded in: 250.019 ms.
[INFO ] - Inference result: [0.51217073, -0.12814945, 0.06952064 ...]
[INFO ] - Throughput: 1.01, completed 1 iteration in 995 ms.
[INFO ] - Model loading time: 250.019 ms.
Deprecated Gradle features were used in this build, making it incompatible with Gradle 7.0.
Use '--warning-mode all' to show the individual deprecation warnings.
See https://docs.gradle.org/6.7.1/userguide/command_line_interface.html#sec:command_line_warnings
BUILD SUCCESSFUL in 3s
18 actionable tasks: 1 executed, 17 up-to-date
zhangyun@zhangyundeMacBook-Pro djl % ./gradlew benchmark -Dai.djl.default_engine=PyTorch --args="-c 1 -s 1,3,300,300 -n ai.djl.pytorch:ssd -r {'size':'300','backbone':'resnet50'}"
/Users/zhangyun/IdeaProjects/djl
/Users/zhangyun/IdeaProjects/djl/gradle/wrapper/gradle-wrapper.jar
java
-Xdock:name=Gradle -Xdock:icon=/Users/zhangyun/IdeaProjects/djl/media/gradle.icns -Dorg.gradle.appname=gradlew -classpath /Users/zhangyun/IdeaProjects/djl/gradle/wrapper/gradle-wrapper.jar org.gradle.wrapper.GradleWrapperMain benchmark -Dai.djl.default_engine=PyTorch --args=-c 1 -s 1,3,300,300 -n ai.djl.pytorch:ssd -r {'size':'300','backbone':'resnet50'}
> Task :examples:benchmark
[INFO ] - Number of inter-op threads is 4
[INFO ] - Number of intra-op threads is 4
[INFO ] - Load library 1.8.1 in 263.288 ms.
[INFO ] - Running Benchmark on: cpu().
Loading: 100% |████████████████████████████████████████|
[INFO ] - Model ssd_300_resnet50 loaded in: 239.253 ms.
[INFO ] - Inference result: [0.51217073, -0.12814945, 0.06952064 ...]
[INFO ] - Throughput: 1.03, completed 1 iteration in 975 ms.
[INFO ] - Model loading time: 239.253 ms.
Deprecated Gradle features were used in this build, making it incompatible with Gradle 7.0.
Use '--warning-mode all' to show the individual deprecation warnings.
See https://docs.gradle.org/6.7.1/userguide/command_line_interface.html#sec:command_line_warnings
BUILD SUCCESSFUL in 3s
18 actionable tasks: 1 executed, 17 up-to-date
zhangyun@zhangyundeMacBook-Pro djl %
You can test your own model as well. Here are a few things to be aware of:
Apart from that:
C++: it seems you keep reusing the same PyTorch input; can you also include tensor creation in the timed section? Please also generate random inputs with torch::uniform to avoid caching effects.
Java: try to benchmark with pure NDArray creation, e.g. only call NDManager.randomUniform() to create an NDArray of a given shape. This will give an apples-to-apples comparison.
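A minimal self-contained sketch of such a pure-NDArray measurement might look like the following; the class name, shape, bounds, and iteration count are illustrative assumptions:

import ai.djl.ndarray.NDArray;
import ai.djl.ndarray.NDManager;
import ai.djl.ndarray.types.Shape;

public final class NDArrayCreationBench {
    public static void main(String[] args) {
        try (NDManager manager = NDManager.newBaseManager()) {
            for (int i = 0; i < 10; i++) { // iteration count is arbitrary
                long start = System.nanoTime();
                // Fresh random input each run, mirroring the torch::uniform advice above
                NDArray array = manager.randomUniform(0f, 1f, new Shape(1, 3, 300, 300));
                long elapsed = System.nanoTime() - start;
                System.out.printf("run %d: %.3f ms%n", i, elapsed / 1e6);
                array.close(); // release native memory promptly
            }
        }
    }
}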
Note: SSD is a model we benchmark every day, and we have seen no difference between Python and Java in recent weeks. You can use our benchmark script to verify: http://docs.djl.ai/master/docs/development/benchmark_with_djl.html
ONNX is very fast, but there seems to be a memory leak 😂 I tried again and found it is not actually a leak: inference is so fast that memory is not released in time, so it runs out of memory (OOM).
Closing this issue due to inactivity; please feel free to reopen if you still have the problem.
Is there any way to solve the OOM problem mentioned above?
Description
When loading a TorchScript-converted .pt model with DJL, I found inference performance is very low: about 8x slower than loading and running the same model directly in Python.
To determine whether the problem lies in my .pt model or in the DJL framework, I loaded the same .pt model in C++ and compared. Only the forward time is measured; the code snippet is as follows:
for (int i = 1; i < 10; i = i + 1) {
    start = clock();  // timing start
    output1 = net.forward({input1, input2, input3, input4});
    end = clock();    // timing end
    std::cout << "The run time " << i << " is: "
              << (double)(end - start) / CLOCKS_PER_SEC << "s" << std::endl;
    std::cout << "output1: " << output1 << std::endl;
}
Results: the first and second inferences are slow (92 ms and 72 ms respectively); from the third run onward the time drops to 11-14 ms.
From the PyTorch GitHub I learned: "We compile the graph for each set of different tensor dimensions that are run and then cache it, so it's likely the first run will be slower." Details: https://github.com/pytorch/pytorch/issues/19106
Since every DJL inference takes about as long as the first C++ inference (around 90 ms), I suspect that the way DJL loads the libtorch library through Java's JNI (or JNA?) effectively reloads it for every inference, so it never benefits from the tensor cache described above, which makes every call slow.
May I ask whether there is any optimization setting available that would solve this performance problem? Many thanks.
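Given the per-shape graph compilation quoted above, one possible mitigation (sketched here as an assumption, not an official DJL setting) is to warm the predictor up at startup with a few inputs of the shapes production traffic will use, so later requests hit the cached graph. The method below reuses QueryDocInput and predictor from the code earlier in this issue; the dummy text and iteration count are illustrative:

private void warmUp() {
    // Hypothetical warm-up: the first predictions trigger TorchScript's
    // per-shape graph compilation, so later requests hit the cache.
    QueryDocInput dummy = new QueryDocInput("warm up", "warm up");
    for (int i = 0; i < 3; i++) {
        try {
            predictor.predict(dummy);
        } catch (TranslateException e) {
            e.printStackTrace();
        }
    }
}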
Environment Info
Please run the command
./gradlew debugEnv
from the root directory of DJL (if necessary, clone DJL first). It will output information about your system, environment, and installation that can help us debug your issue. Paste the output of the command below: