deepjavalibrary / djl

An Engine-Agnostic Deep Learning Framework in Java
https://djl.ai
Apache License 2.0

DJL 0.23.0 + torch 2.0.1 GPU multi-threading inference issue #2778

Open · jestiny0 opened this issue 10 months ago

jestiny0 commented 10 months ago

Description

I have updated DJL to version 0.23.0 and PyTorch to version 2.0.1. However, when performing multi-threaded inference on GPU, end-to-end latency grows without bound, which looks like a memory leak. I am currently using DJL 0.22.0 with PyTorch 2.0.0 and do not encounter any issues in the same stress-testing environment.

I have looked into several similar issues; some of them mention problems with multi-threaded GPU inference on both PyTorch 2.0.0 and PyTorch 2.0.1. However, I personally did not encounter any issues with PyTorch 2.0.0.

Furthermore, I tried setting export TORCH_CUDNN_V8_API_DISABLED=1 based on the referenced issue, but it did not resolve the problem for me.

My question

I noticed that the DJL master branch already sets PyTorch 2.0.1 as the default version. I'm curious whether you have made any modifications to address this issue, or whether there are other plans in place.

frankfliu commented 10 months ago

This is a bug in the 0.23.0 container; please set the following environment variable:

NO_OMP_NUM_THREADS=true

See this PR: https://github.com/deepjavalibrary/djl-serving/pull/1073

jestiny0 commented 10 months ago

@frankfliu This doesn't help in my case. I am not using DJL Serving or its containers; I use pytorch-engine directly to run a TorchScript model. I tried setting the environment variable NO_OMP_NUM_THREADS=true as well as OMP_NUM_THREADS=1, but it still didn't work.

I would like to understand the root cause of this problem and how to fix it. I am worried that this issue will prevent me from upgrading my DJL version going forward.

frankfliu commented 10 months ago

@jestiny0 First of all, you can use an older version of PyTorch with 0.23.0.

Did you set the thread configuration for PyTorch? See https://docs.djl.ai/docs/development/inference_performance_optimization.html#thread-configuration

You might also want to take a look at this: https://docs.djl.ai/docs/development/inference_performance_optimization.html#graph-executor-optimization

Are you running on GPU? If you use CUDA, you must set TORCH_CUDNN_V8_API_DISABLED=1.
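
For reference, a minimal sketch (not an official recommendation) of wiring up those settings in a plain JVM deployment without DJL Serving. The property names come from the thread-configuration doc linked above, and the values are placeholders to tune; TORCH_CUDNN_V8_API_DISABLED and OMP_NUM_THREADS are environment variables, so they must be exported before the JVM starts rather than set from Java:

// Minimal sketch: apply the PyTorch thread configuration before the engine loads.
// The values below are placeholders; tune them for your hardware.
public final class DjlTuning {

    public static void configureBeforeEngineInit() {
        // System properties read by the DJL PyTorch engine (see the doc linked above).
        System.setProperty("ai.djl.pytorch.num_interop_threads", "1");
        System.setProperty("ai.djl.pytorch.num_threads", "1");
    }
}

Equivalently, these can be passed as JVM flags (-Dai.djl.pytorch.num_interop_threads=1 -Dai.djl.pytorch.num_threads=1), with TORCH_CUDNN_V8_API_DISABLED=1 exported in the launch script.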

Did you try djl-bench? You can run stress tests with different PyTorch versions.

jestiny0 commented 10 months ago

@frankfliu Using an older version would certainly solve my problem. Currently I am running DJL 0.22.0 + torch 2.0.0 in production, with both GPU and CPU models online. I have configured the thread settings and graph optimization, and those models run fine.

However, we are now preparing to upgrade our offline models to PyTorch 2.0.1, so I tried upgrading the online service to the latest DJL 0.23.0 + torch 2.0.1. It performs well on CPU but hits severe performance and memory issues on GPU. I tried adding TORCH_CUDNN_V8_API_DISABLED=1, but it didn't help. I am concerned that the GPU performance issue will prevent me from upgrading the online service to torch 2.0.1.

jestiny0 commented 10 months ago

@frankfliu Do you have any better suggestions or points worth trying? Upgrading torch is quite urgent for us. Thank you.

frankfliu commented 10 months ago

@jestiny0

Can you run djl-bench with your model and see if you can reproduce the issue?

We didn't observe the performance issue with PyTorch 2.0.1. Can you share your model?

jestiny0 commented 10 months ago

@frankfliu Sorry, let me correct my earlier description: the issue is that after upgrading to the latest version, there is a significant performance drop when serving on GPU.

I did not use djl-bench because my model has complex inputs including multiple inputs and dictionaries. I have provided code to generate the model: https://github.com/jestiny0/djl_gpu_multithreads_issue_model.

I am using the djl pytorch engine for deployment and serving, but I have not provided the Java code for that part because it is complex.

DJL configuration:

<dependency>
  <groupId>ai.djl</groupId>
  <artifactId>api</artifactId>
  <version>${djl.version}</version>
</dependency>
<dependency>
  <groupId>ai.djl.pytorch</groupId>
  <artifactId>pytorch-engine</artifactId>
  <version>${djl.version}</version>
</dependency>

<dependency>
  <groupId>ai.djl.pytorch</groupId>
  <artifactId>pytorch-native-cu118</artifactId>
  <classifier>linux-x86_64</classifier>
  <version>${torch.version}</version>
  <scope>runtime</scope>
</dependency>
<dependency>
  <groupId>ai.djl.pytorch</groupId>
  <artifactId>pytorch-jni</artifactId>
  <version>${torch.version}-${djl.version}</version>
  <scope>runtime</scope>
</dependency>
<!-- DJL dependencies above -->

The configuration of the version running well online is as follows:

<properties>
  <djl.version>0.22.0</djl.version>
  <torch.version>2.0.0</torch.version>
</properties>

The stress-test latency is shown in the attached screenshot.

The configuration of the version experiencing performance issues is as follows:

<properties>
  <djl.version>0.23.0</djl.version>
  <torch.version>2.0.1</torch.version>
</properties>

The stress-test latency is shown in the attached screenshot.

It can be seen that latency increases overall. When the stress-test QPS reaches 150, the old version still runs normally, while the new version shows a sharp increase in end-to-end latency; requests queue up heavily in the Java thread pool and the service can no longer keep up.

jestiny0 commented 10 months ago

@frankfliu Looking forward to your new discoveries and suggestions, thanks!

frankfliu commented 10 months ago

@jestiny0 Sorry for the delay. Can you share the Java code that runs the inference?

jestiny0 commented 10 months ago

@frankfliu Sorry for not replying earlier due to the holiday season. The Java code is quite complex, but here are the main sections. The following code loads the model and is executed only once:

PtModel model = (PtModel) new PtEngineProvider().getEngine().newModel(
        config.getName(),
        config.getDevice());
try {
    Map<String, String> options = Maps.newHashMap();
    options.put("extraFiles", PyTorchPredictor.MODEL_METADATA_KEY);
    if (config.getDevice().isGpu()) {
        options.put("mapLocation", "true");
    }
    model.load(Paths.get(config.getPath()), null, options);
} catch (MalformedModelException | IOException e) {
    log.error("Failed to load pytorch predictor", e);
    throw new RuntimeException("Failed to load pytorch predictor", e);
}

Here is the code for constructing the request. The provided model only has the tensor_feature field populated:

Map<String, PtNDArray> tensorFeatures = Maps.newHashMap();
// ...
for (FeatureInfo featureInfo : featureInfos) {
    // ... other feature types handled here; featureValues for this feature is resolved in elided code
    if (featureInfo.getDataType() == DataType.TENSOR) {
        extractTensorFeature(tensorFeatures, featureValues, featureInfo);
    } else {
        throw new IllegalStateException("Unknown feature info data type");
    }
}

// ......
IValue denseFeaturesIValue = IValue.stringMapFrom(denseFeatures);
IValue sparseFeaturesIValue = IValue.listFrom(sparseFeatures.stream().toArray(IValue[]::new));
IValue embeddingFeaturesIValue = IValue.stringMapFrom(embeddingFeatures);
IValue tensorFeatureIValue = IValue.stringMapFrom(tensorFeatures);
IValue[] result = new IValue[] {
        denseFeaturesIValue,
        sparseFeaturesIValue,
        embeddingFeaturesIValue,
        tensorFeatureIValue
};

context.setProcessedInput(result);
context.setIntermediateIValues(intermediateIValues);

return context;
}

private void extractTensorFeature(
        Map<String, PtNDArray> tensorFeatures,
        Optional<FeatureValues> featureValues,
        FeatureInfo info) {
    // ....
    NDManager ndManager = ptModel.getNDManager().newSubManager();
    Tensor tensorFeature = featureValues.get().getTensorFeatures();
    PtNDArray tensor = (PtNDArray) threadLocalBuffers.get().allocateTensor(tensorFeature, info, ndManager);

    tensorFeatures.put(info.getName(), tensor);
    // ....
}

Code that runs the inference:

protected PyTorchPredictionContext predictInternal(PyTorchPredictionContext context) {
    IValue[] inputIValues = context.getProcessedInput();
    IValue result = ((PtSymbolBlock) ptModel.getBlock()).forward(inputIValues);
    context.setPreprocessOutput(result);
    Arrays.stream(inputIValues).forEach(ivalue -> ivalue.close());
    context.getIntermediateIValues().stream().forEach(ivalue -> ivalue.close());
    return context;
}

It is worth mentioning that, apart from the model-loading code, which runs only once, the request-construction and inference code is executed on separate threads from a thread pool, so different requests are processed concurrently.

Besides, the above Java code remains unchanged before and after upgrading the DJL version.
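
For illustration only, a simplified sketch of that threading model: the executor setup, buildRequestContext and request are placeholders rather than the actual service code, predictInternal is the method shown above, and ExecutorService, Executors and Future come from java.util.concurrent.

// Sketch of the concurrency described above: the model is loaded once, while
// request construction and inference both run on worker threads from a pool.
ExecutorService executor = Executors.newFixedThreadPool(16); // pool size is a placeholder

Future<PyTorchPredictionContext> future = executor.submit(() -> {
    // buildRequestContext(request) stands in for the request-construction code above
    PyTorchPredictionContext context = buildRequestContext(request);
    return predictInternal(context); // inference runs on the worker thread
});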

frankfliu commented 9 months ago

@jestiny0

It's hard for me to figure out your inputs from the code. Can you write code that creates empty tensors for each input (e.g. manager.zeros(new Shape(1, 1000)))? Or just let me know the shapes:

denseFeaturesIValue: (?, ?)
sparseFeaturesIValue: [(?, ?), (?, ?)]
embeddingFeaturesIValue: { "key1": ?, "key2": ? }
tensorFeatureIValue: { "key1": ?, "key2": ? }

jestiny0 commented 9 months ago

@frankfliu Here is a list of all the features. You can see the shape and type of each feature in the tensor_properties attribute of each line. If the first element of the shape is -1, it represents the batch size, which you can set to 200. For example, the shape of feature check is [200], the shape of feature local_x_semantic_emb is [256, 50], and the shape of feature y_pid is [200, 3].

Here is the code for generating input in Python. You can refer to it and modify it to use DJL.

All features are present in the tensorFeatureIValue. The other three inputs, denseFeaturesIValue, sparseFeaturesIValue, and embeddingFeaturesIValue, are all empty.
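
To make the input concrete, here is a rough sketch (not the production code) that builds zero-filled dummy tensors for a few of those features in DJL. The dtypes are assumed to be float32 and should be adjusted per tensor_properties, and the cast to PtNDArray assumes the PyTorch engine is the default engine:

import java.util.HashMap;
import java.util.Map;

import ai.djl.Device;
import ai.djl.ndarray.NDManager;
import ai.djl.ndarray.types.Shape;
import ai.djl.pytorch.engine.PtNDArray;
import ai.djl.pytorch.jni.IValue;

public final class DummyInputs {

    // Builds the tensorFeatureIValue input from zero tensors (batch size 200).
    // Only three of the features are shown; the full list comes from tensor_properties.
    static IValue buildTensorFeatures(NDManager manager) {
        Map<String, PtNDArray> tensorFeatures = new HashMap<>();
        tensorFeatures.put("check", (PtNDArray) manager.zeros(new Shape(200)));
        tensorFeatures.put("local_x_semantic_emb", (PtNDArray) manager.zeros(new Shape(256, 50)));
        tensorFeatures.put("y_pid", (PtNDArray) manager.zeros(new Shape(200, 3)));
        return IValue.stringMapFrom(tensorFeatures);
    }

    public static void main(String[] args) {
        try (NDManager manager = NDManager.newBaseManager(Device.gpu())) {
            IValue tensorFeatureIValue = buildTensorFeatures(manager);
            // denseFeaturesIValue, sparseFeaturesIValue and embeddingFeaturesIValue are
            // built as in the earlier snippet, just from empty collections.
            tensorFeatureIValue.close();
        }
    }
}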

jestiny0 commented 8 months ago

@frankfliu I'm sorry to bother you again after such a long time, but I wanted to check whether there has been any progress on the issue mentioned earlier. I have further tested djl0.24.0 + torch2.0.1 and djl0.25.0 + torch2.0.1, and the issue still exists. Below are my latest load-test results:

Index | Environment                     | QPS | Latency P50 (ms) | Latency P99 (ms) | GPU Utilization (%)
1     | baseline (djl0.22.0+torch2.0.0) | 100 | 5                | 7                | 28
2     | baseline (djl0.22.0+torch2.0.0) | 250 | 8                | 13               | 71
3     | djl0.24.0+torch2.0.1            | 100 | 7                | 9                | 40
4     | djl0.24.0+torch2.0.1            | 250 | 34               | 52               | 81

As we can see, at 100 QPS the latency already increases with the new version. At 250 QPS the latency increases dramatically and the GPU can no longer handle the load, so requests keep building up in the Java thread-pool queue and end-to-end latency grows without bound.

By the way, the data I provided above is from load testing, not real production traffic. I'm planning to divert a part of real traffic to observe the actual performance. Nonetheless, I'm still looking forward to any further findings or insights you may have to share.

jestiny0 commented 8 months ago

@frankfliu I ran an A/B experiment in production to further confirm the issue. The model itself was unchanged. On the same AWS g5.2xlarge instance type (8 CPUs, 1 GPU), during peak hours and at similar CPU utilization, the experiment arm averaged 58 ms latency with a p99 of 80 ms, while the baseline averaged 44 ms with a p99 of 58 ms. That corresponds to a throughput difference of approximately 30%.

Now I plan to narrow down the investigation, focusing mainly on the upgrade from DJL 0.22.0 to 0.23.0 and the corresponding pytorch-jni upgrade from 2.0.0-0.22.0 to 2.0.1-0.23.0. I want to check what changed in DJL between these versions (it is difficult to troubleshoot the PyTorch changes). If you make any further discoveries, I would appreciate it if you could share them.