deepjavalibrary / djl

An Engine-Agnostic Deep Learning Framework in Java
https://djl.ai
Apache License 2.0

Regression: EngineException: default_program(22): error: extra text after expected end of number with DJL 0.28.0 + intfloat/multilingual-e5-small on machine with GPU #3089

Open david-sitsky opened 7 months ago

david-sitsky commented 7 months ago

Description

Trying to perform predictions using intfloat/multilingual-e5-small fails on a machine with a GPU. This used to work in DJL 0.26.0 with PyTorch 2.0.1 but now fails on 0.28.0 (and presumably 0.27.0).

Expected Behavior

It performs a prediction without error.

Error Message

Caused by: ai.djl.translate.TranslateException: ai.djl.engine.EngineException: default_program(22): error: extra text after expected end of number
      aten_mul[(long long)(threadIdx.x) + 512ll * (long long)(blockIdx.x)] = v * -3.402823466385289e+38.f;
                                                                                                       ^

default_program(26): error: extra text after expected end of number
      aten_add[(long long)(threadIdx.x) + 512ll * (long long)(blockIdx.x)] = v_1 / 5.656854152679443f + v_2 * -3.402823466385289e+38.f;
                                                                                                                                    ^

2 errors detected in the compilation of "default_program".

nvrtc compilation failed: 

#define NAN __int_as_float(0x7fffffff)
#define POS_INFINITY __int_as_float(0x7f800000)
#define NEG_INFINITY __int_as_float(0xff800000)

template<typename T>
__device__ T maximum(T a, T b) {
  return isnan(a) ? a : (a > b ? a : b);
}

template<typename T>
__device__ T minimum(T a, T b) {
  return isnan(a) ? a : (a < b ? a : b);
}

extern "C" __global__
void fused_mul_div_add(float* tattention_scores_2, float* tv_, float* aten_add, float* aten_mul) {
{
if ((long long)(threadIdx.x) + 512ll * (long long)(blockIdx.x)<5ll ? 1 : 0) {
    float v = __ldg(tv_ + (long long)(threadIdx.x) + 512ll * (long long)(blockIdx.x));
    aten_mul[(long long)(threadIdx.x) + 512ll * (long long)(blockIdx.x)] = v * -3.402823466385289e+38.f;
  }if ((long long)(threadIdx.x) + 512ll * (long long)(blockIdx.x)<300ll ? 1 : 0) {
    float v_1 = __ldg(tattention_scores_2 + (long long)(threadIdx.x) + 512ll * (long long)(blockIdx.x));
    float v_2 = __ldg(tv_ + ((long long)(threadIdx.x) + 512ll * (long long)(blockIdx.x)) % 5ll);
    aten_add[(long long)(threadIdx.x) + 512ll * (long long)(blockIdx.x)] = v_1 / 5.656854152679443f + v_2 * -3.402823466385289e+38.f;
  }}
}

        at ai.djl.inference.Predictor.batchPredict(Predictor.java:195) ~[api-0.28.0-SNAPSHOT.jar:?]

This seems similar to what is reported in https://github.com/deepjavalibrary/djl/issues/2962, but according to https://github.com/deepjavalibrary/djl/blob/master/engines/pytorch/pytorch-engine/README.md, DJL 0.27.0 and 0.28.0 no longer support PyTorch 2.0.1. I tried it anyway, but it does indeed fail:

ai.djl.engine.EngineException: Cannot download jni files: https://publish.djl.ai/pytorch/2.0.1/jnilib/0.28.0/linux-x86_64/cu118/libdjl_torch.so
    at ai.djl.pytorch.jni.LibUtils.downloadJniLib(LibUtils.java:542) ~[pytorch-engine-0.28.0-SNAPSHOT.jar:?]
    at ai.djl.pytorch.jni.LibUtils.findJniLibrary(LibUtils.java:276) ~[pytorch-engine-0.28.0-SNAPSHOT.jar:?]
    at ai.djl.pytorch.jni.LibUtils.loadLibrary(LibUtils.java:84) ~[pytorch-engine-0.28.0-SNAPSHOT.jar:?]
    at ai.djl.pytorch.engine.PtEngine.newInstance(PtEngine.java:53) ~[pytorch-engine-0.28.0-SNAPSHOT.jar:?]
    at ai.djl.pytorch.engine.PtEngineProvider.getEngine(PtEngineProvider.java:41) ~[pytorch-engine-0.28.0-SNAPSHOT.jar:?]
    at ai.djl.engine.Engine.getEngine(Engine.java:190) ~[api-0.28.0-SNAPSHOT.jar:?]
    at ai.djl.Model.newInstance(Model.java:99) ~[api-0.28.0-SNAPSHOT.jar:?]
    at ai.djl.repository.zoo.BaseModelLoader.createModel(BaseModelLoader.java:196) ~[api-0.28.0-SNAPSHOT.jar:?]
    at ai.djl.repository.zoo.BaseModelLoader.loadModel(BaseModelLoader.java:159) ~[api-0.28.0-SNAPSHOT.jar:?]
    at ai.djl.repository.zoo.Criteria.loadModel(Criteria.java:172) ~[api-0.28.0-SNAPSHOT.jar:?]

@frankfliu - many thanks for all your recent fixes, but it is not clear what can be done in this situation other than PyTorch fixing https://github.com/pytorch/pytorch/issues/107503. Or is there a workaround? Many thanks in advance.

frankfliu commented 7 months ago

You can still use DJL 0.28.0 with PyTorch 2.0.1:

export PYTORCH_VERSION=2.0.1
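
If an environment variable is awkward to set in your deployment, the same override can likely be supplied as a Java system property before the PyTorch engine initializes. A minimal sketch (treat this as an assumption if your DJL version only reads the environment variable):

public final class PinPyTorchVersion {
    public static void main(String[] args) {
        // Equivalent to `export PYTORCH_VERSION=2.0.1`; must be set before the PyTorch engine loads.
        System.setProperty("PYTORCH_VERSION", "2.0.1");
        System.out.println("PyTorch engine: " + ai.djl.engine.Engine.getEngine("PyTorch").getVersion());
    }
}
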
david-sitsky commented 7 months ago

I tried that, as I mentioned above, but it fails with this:

ai.djl.engine.EngineException: Cannot download jni files: https://publish.djl.ai/pytorch/2.0.1/jnilib/0.28.0/linux-x86_64/cu118/libdjl_torch.so
    at ai.djl.pytorch.jni.LibUtils.downloadJniLib(LibUtils.java:542) ~[pytorch-engine-0.28.0-SNAPSHOT.jar:?]
    at ai.djl.pytorch.jni.LibUtils.findJniLibrary(LibUtils.java:276) ~[pytorch-engine-0.28.0-SNAPSHOT.jar:?]
frankfliu commented 7 months ago

@david-sitsky

Please try it again. I added 2.0.1 support for 0.28.0-SNAPSHOT

david-sitsky commented 7 months ago

Many thanks @frankfliu for your quick reply - that works.

Hopefully we can somehow convince PyTorch to fix this issue, as staying pegged to 2.0.1 is not a great long-term solution. They claim libtorch is in maintenance mode, but this is a regression.

frankfliu commented 7 months ago

torch.export is still experimental; PyTorch should not put torch.jit into maintenance mode until torch.export is ready.

david-sitsky commented 7 months ago

@frankfliu - this worked fine on Linux, but on Windows (I have to support that platform too, sadly) it fails with what looks like the same error, even when using PyTorch 2.0.1. The logs below confirm that version of PyTorch is being used. I installed the NVIDIA CUDA Toolkit 11.8 and cuDNN 8.9.7, which I read are compatible with each other and with that version of PyTorch. Is there something else I have done wrong here? Thanks again for any advice.

2024-04-16 05:25:05.759 +0000 [main] 121845 WARN  ai.djl.pytorch.jni.LibUtils - Override PyTorch version: 2.0.1.
2024-04-16 05:25:06.269 +0000 [main] 122355 INFO  ai.djl.pytorch.engine.PtEngine - PyTorch graph executor optimizer is enabled, this may impact your inference latency and throughput. See: https://docs.djl.ai/docs/development/inference_performance_optimization.html#graph-executor-optimization
2024-04-16 05:25:06.278 +0000 [main] 122364 INFO  ai.djl.pytorch.engine.PtEngine - Number of inter-op threads is 16
2024-04-16 05:25:06.278 +0000 [main] 122364 INFO  ai.djl.pytorch.engine.PtEngine - Number of intra-op threads is 16
2024-04-16 05:25:06.717 +0000 [main] 122803 INFO  ai.djl.util.Platform - Found matching platform from: jar:file:/C:/Program%20Files/.../lib/tokenizers-0.28.0-SNAPSHOT.jar!/native/lib/tokenizers.properties
java.io.UncheckedIOException: Unexpected error
...
Caused by: ai.djl.translate.TranslateException: ai.djl.engine.EngineException: default_program(22): error: extra text after expected end of number

default_program(26): error: extra text after expected end of number

2 errors detected in the compilation of "default_program".

nvrtc compilation failed: 

#define NAN __int_as_float(0x7fffffff)
#define POS_INFINITY __int_as_float(0x7f800000)
#define NEG_INFINITY __int_as_float(0xff800000)

template<typename T>
__device__ T maximum(T a, T b) {
  return isnan(a) ? a : (a > b ? a : b);
}

template<typename T>
__device__ T minimum(T a, T b) {
  return isnan(a) ? a : (a < b ? a : b);
}

extern "C" __global__
void fused_mul_div_add(float* tattention_scores_2, float* tv_, float* aten_add, float* aten_mul) {
{
if ((long long)(threadIdx.x) + 512ll * (long long)(blockIdx.x)<5ll ? 1 : 0) {
    float v = __ldg(tv_ + (long long)(threadIdx.x) + 512ll * (long long)(blockIdx.x));
    aten_mul[(long long)(threadIdx.x) + 512ll * (long long)(blockIdx.x)] = v * -3.402823466385289e+38.f;
  }if ((long long)(threadIdx.x) + 512ll * (long long)(blockIdx.x)<300ll ? 1 : 0) {
    float v_1 = __ldg(tattention_scores_2 + (long long)(threadIdx.x) + 512ll * (long long)(blockIdx.x));
    float v_2 = __ldg(tv_ + ((long long)(threadIdx.x) + 512ll * (long long)(blockIdx.x)) % 5ll);
    aten_add[(long long)(threadIdx.x) + 512ll * (long long)(blockIdx.x)] = v_1 / 5.656854152679443f + v_2 * -3.402823466385289e+38.f;
  }}
}

    at ai.djl.inference.Predictor.batchPredict(Predictor.java:195) ~[api-0.28.0-SNAPSHOT.jar:?]
        ...
    ... 14 more
Caused by: ai.djl.engine.EngineException: default_program(22): error: extra text after expected end of number

default_program(26): error: extra text after expected end of number

2 errors detected in the compilation of "default_program".
frankfliu commented 7 months ago

Can you convert the model to onnx?

david-sitsky commented 7 months ago

I can try that route, but I was hoping to avoid it: the current code dynamically downloads the Hugging Face models to the user's machine, which is super convenient, and I'd rather not have to ship model files explicitly.

Any ideas why Windows is affected like this? I thought this bug was specific to newer PyTorch versions, so it seems really odd that 2.0.1 shows the issue on Windows.

david-sitsky commented 7 months ago

@frankfliu, FWIW using ONNX seems to work, which is great! I saw the recent commit adding support for converting HuggingFace models to ONNX: https://github.com/deepjavalibrary/djl/pull/3093/files. @xyang16, as an FYI, while this worked I got an error from model_zoo_importer.py that you might want to fix:

djl/extensions/tokenizers$ python3 src/main/python/model_zoo_importer.py -m intfloat/multilingual-e5-small -f OnnxRuntime
...
Validating ONNX model tmp/model.onnx...
    -[✓] ONNX model output names match reference model (sentence_embedding, token_embeddings)
    - Validating ONNX Model output "token_embeddings":
        -[✓] (2, 16, 384) matches (2, 16, 384)
        -[✓] all values close (atol: 1e-05)
    - Validating ONNX Model output "sentence_embedding":
        -[✓] (2, 384) matches (2, 384)
        -[✓] all values close (atol: 1e-05)
The ONNX export succeeded and the exported model was saved at: tmp
...
Saving DJL model as zip: multilingual-e5-small.zip ...
Failed to convert model: intfloat/multilingual-e5-small.
cannot unpack non-iterable NoneType object
Traceback (most recent call last):
  File "/data/djl/djl/extensions/tokenizers/src/main/python/model_zoo_importer.py", line 54, in main
    result, reason, size = converter.save_model(
TypeError: cannot unpack non-iterable NoneType object
finished.

Strangely, after switching to ONNX for E5, the PyTorch model djl://ai.djl.huggingface.pytorch/sentence-transformers/clip-ViT-B-32-multilingual-v1 now fails to load when it used to work. Perhaps this is related to the "text embedding translator regression" bug you just fixed in https://github.com/deepjavalibrary/djl/pull/3095? In any case, I might convert sentence-transformers/clip-ViT-B-32-multilingual-v1 to ONNX as well so I can avoid using PyTorch completely. Any ideas on what might be causing this?

ai.djl.engine.EngineException: Could not run 'aten::empty_strided' with arguments from the 'CUDA' backend. This could be because the operator doesn't exist for this backend, or was omitted during the selective/custom build process (if using custom build). If you are a Facebook employee using PyTorch on mobile, please visit https://fburl.com/ptmfixes for possible resolutions. 'aten::empty_strided' is only available for these backends: [CPU, Meta, QuantizedCPU, BackendSelect, Python, FuncTorchDynamicLayerBackMode, Functionalize, Named, Conjugate, Negative, ZeroTensor, ADInplaceOrView, AutogradOther, AutogradCPU, AutogradCUDA, AutogradHIP, AutogradXLA, AutogradMPS, AutogradIPU, AutogradXPU, AutogradHPU, AutogradVE, AutogradLazy, AutogradMTIA, AutogradPrivateUse1, AutogradPrivateUse2, AutogradPrivateUse3, AutogradMeta, AutogradNestedTensor, Tracer, AutocastCPU, AutocastCUDA, FuncTorchBatched, FuncTorchVmapMode, Batched, VmapMode, FuncTorchGradWrapper, PythonTLSSnapshot, FuncTorchDynamicLayerFrontMode, PreDispatch, PythonDispatcher].

CPU: registered at C:\actions-runner\_work\pytorch\pytorch\builder\windows\pytorch\build\build\aten\src\ATen\RegisterCPU.cpp:31188 [kernel]
Meta: registered at C:\actions-runner\_work\pytorch\pytorch\builder\windows\pytorch\build\build\aten\src\ATen\RegisterMeta.cpp:26829 [kernel]
QuantizedCPU: registered at C:\actions-runner\_work\pytorch\pytorch\builder\windows\pytorch\build\build\aten\src\ATen\RegisterQuantizedCPU.cpp:944 [kernel]
BackendSelect: registered at C:\actions-runner\_work\pytorch\pytorch\builder\windows\pytorch\build\build\aten\src\ATen\RegisterBackendSelect.cpp:742 [kernel]
Python: registered at C:\actions-runner\_work\pytorch\pytorch\builder\windows\pytorch\aten\src\ATen\core\PythonFallbackKernel.cpp:153 [backend fallback]
FuncTorchDynamicLayerBackMode: registered at C:\actions-runner\_work\pytorch\pytorch\builder\windows\pytorch\aten\src\ATen\functorch\DynamicLayer.cpp:498 [backend fallback]
Functionalize: registered at C:\actions-runner\_work\pytorch\pytorch\builder\windows\pytorch\aten\src\ATen\FunctionalizeFallbackKernel.cpp:290 [backend fallback]
Named: registered at C:\actions-runner\_work\pytorch\pytorch\builder\windows\pytorch\aten\src\ATen\core\NamedRegistrations.cpp:7 [backend fallback]
Conjugate: fallthrough registered at C:\actions-runner\_work\pytorch\pytorch\builder\windows\pytorch\aten\src\ATen\ConjugateFallback.cpp:21 [kernel]
Negative: fallthrough registered at C:\actions-runner\_work\pytorch\pytorch\builder\windows\pytorch\aten\src\ATen\native\NegateFallback.cpp:23 [kernel]
ZeroTensor: fallthrough registered at C:\actions-runner\_work\pytorch\pytorch\builder\windows\pytorch\aten\src\ATen\ZeroTensorFallback.cpp:90 [kernel]
ADInplaceOrView: fallthrough registered at C:\actions-runner\_work\pytorch\pytorch\builder\windows\pytorch\aten\src\ATen\core\VariableFallbackKernel.cpp:86 [backend fallback]
AutogradOther: registered at C:\actions-runner\_work\pytorch\pytorch\builder\windows\pytorch\torch\csrc\autograd\generated\VariableType_2.cpp:18694 [autograd kernel]
AutogradCPU: registered at C:\actions-runner\_work\pytorch\pytorch\builder\windows\pytorch\torch\csrc\autograd\generated\VariableType_2.cpp:18694 [autograd kernel]
AutogradCUDA: registered at C:\actions-runner\_work\pytorch\pytorch\builder\windows\pytorch\torch\csrc\autograd\generated\VariableType_2.cpp:18694 [autograd kernel]
AutogradHIP: registered at C:\actions-runner\_work\pytorch\pytorch\builder\windows\pytorch\torch\csrc\autograd\generated\VariableType_2.cpp:18694 [autograd kernel]
AutogradXLA: registered at C:\actions-runner\_work\pytorch\pytorch\builder\windows\pytorch\torch\csrc\autograd\generated\VariableType_2.cpp:18694 [autograd kernel]
AutogradMPS: registered at C:\actions-runner\_work\pytorch\pytorch\builder\windows\pytorch\torch\csrc\autograd\generated\VariableType_2.cpp:18694 [autograd kernel]
AutogradIPU: registered at C:\actions-runner\_work\pytorch\pytorch\builder\windows\pytorch\torch\csrc\autograd\generated\VariableType_2.cpp:18694 [autograd kernel]
AutogradXPU: registered at C:\actions-runner\_work\pytorch\pytorch\builder\windows\pytorch\torch\csrc\autograd\generated\VariableType_2.cpp:18694 [autograd kernel]
AutogradHPU: registered at C:\actions-runner\_work\pytorch\pytorch\builder\windows\pytorch\torch\csrc\autograd\generated\VariableType_2.cpp:18694 [autograd kernel]
AutogradVE: registered at C:\actions-runner\_work\pytorch\pytorch\builder\windows\pytorch\torch\csrc\autograd\generated\VariableType_2.cpp:18694 [autograd kernel]
AutogradLazy: registered at C:\actions-runner\_work\pytorch\pytorch\builder\windows\pytorch\torch\csrc\autograd\generated\VariableType_2.cpp:18694 [autograd kernel]
AutogradMTIA: registered at C:\actions-runner\_work\pytorch\pytorch\builder\windows\pytorch\torch\csrc\autograd\generated\VariableType_2.cpp:18694 [autograd kernel]
AutogradPrivateUse1: registered at C:\actions-runner\_work\pytorch\pytorch\builder\windows\pytorch\torch\csrc\autograd\generated\VariableType_2.cpp:18694 [autograd kernel]
AutogradPrivateUse2: registered at C:\actions-runner\_work\pytorch\pytorch\builder\windows\pytorch\torch\csrc\autograd\generated\VariableType_2.cpp:18694 [autograd kernel]
AutogradPrivateUse3: registered at C:\actions-runner\_work\pytorch\pytorch\builder\windows\pytorch\torch\csrc\autograd\generated\VariableType_2.cpp:18694 [autograd kernel]
AutogradMeta: registered at C:\actions-runner\_work\pytorch\pytorch\builder\windows\pytorch\torch\csrc\autograd\generated\VariableType_2.cpp:18694 [autograd kernel]
AutogradNestedTensor: registered at C:\actions-runner\_work\pytorch\pytorch\builder\windows\pytorch\torch\csrc\autograd\generated\VariableType_2.cpp:18694 [autograd kernel]
Tracer: registered at C:\actions-runner\_work\pytorch\pytorch\builder\windows\pytorch\torch\csrc\autograd\generated\TraceType_2.cpp:17079 [kernel]
AutocastCPU: fallthrough registered at C:\actions-runner\_work\pytorch\pytorch\builder\windows\pytorch\aten\src\ATen\autocast_mode.cpp:382 [backend fallback]
AutocastCUDA: fallthrough registered at C:\actions-runner\_work\pytorch\pytorch\builder\windows\pytorch\aten\src\ATen\autocast_mode.cpp:249 [backend fallback]
FuncTorchBatched: registered at C:\actions-runner\_work\pytorch\pytorch\builder\windows\pytorch\aten\src\ATen\functorch\LegacyBatchingRegistrations.cpp:710 [backend fallback]
FuncTorchVmapMode: fallthrough registered at C:\actions-runner\_work\pytorch\pytorch\builder\windows\pytorch\aten\src\ATen\functorch\VmapModeRegistrations.cpp:28 [backend fallback]
Batched: registered at C:\actions-runner\_work\pytorch\pytorch\builder\windows\pytorch\aten\src\ATen\LegacyBatchingRegistrations.cpp:1075 [backend fallback]
VmapMode: fallthrough registered at C:\actions-runner\_work\pytorch\pytorch\builder\windows\pytorch\aten\src\ATen\VmapModeRegistrations.cpp:33 [backend fallback]
FuncTorchGradWrapper: registered at C:\actions-runner\_work\pytorch\pytorch\builder\windows\pytorch\aten\src\ATen\functorch\TensorWrapper.cpp:203 [backend fallback]
PythonTLSSnapshot: registered at C:\actions-runner\_work\pytorch\pytorch\builder\windows\pytorch\aten\src\ATen\core\PythonFallbackKernel.cpp:161 [backend fallback]
FuncTorchDynamicLayerFrontMode: registered at C:\actions-runner\_work\pytorch\pytorch\builder\windows\pytorch\aten\src\ATen\functorch\DynamicLayer.cpp:494 [backend fallback]
PreDispatch: registered at C:\actions-runner\_work\pytorch\pytorch\builder\windows\pytorch\aten\src\ATen\core\PythonFallbackKernel.cpp:165 [backend fallback]
PythonDispatcher: registered at C:\actions-runner\_work\pytorch\pytorch\builder\windows\pytorch\aten\src\ATen\core\PythonFallbackKernel.cpp:157 [backend fallback]

    at ai.djl.pytorch.jni.PyTorchLibrary.moduleLoad(Native Method) ~[pytorch-engine-0.28.0-SNAPSHOT.jar:?]
    at ai.djl.pytorch.jni.JniUtils.loadModule(JniUtils.java:1742) ~[pytorch-engine-0.28.0-SNAPSHOT.jar:?]
    at ai.djl.pytorch.engine.PtModel.load(PtModel.java:98) ~[pytorch-engine-0.28.0-SNAPSHOT.jar:?]
    at ai.djl.repository.zoo.BaseModelLoader.loadModel(BaseModelLoader.java:166) ~[api-0.28.0-SNAPSHOT.jar:?]
    at ai.djl.repository.zoo.Criteria.loadModel(Criteria.java:172) ~[api-0.28.0-SNAPSHOT.jar:?]
frankfliu commented 7 months ago

Fix model import issue: https://github.com/deepjavalibrary/djl/pull/3098

The PyTorch issue should not be related to the Translator changes. It might be caused by the model being jit-traced with a different PyTorch version than the one DJL is using.

david-sitsky commented 7 months ago

@frankfliu - while OnnxRuntime works, I have found it to be about three times slower than PyTorch, even when using optOption("ortDevice", "TensorRT"). I do realise there is some warmup time, but inference looks noticeably slower for some reason.

Re: the Windows issue with PyTorch 2.0.1 ("error: extra text after expected end of number") that I mentioned in https://github.com/deepjavalibrary/djl/issues/3089#issuecomment-2058286436 - I don't believe it is a case of "the jit trace and DJL PyTorch using different versions", since I am dynamically downloading the models. I have also zapped the .djl.ai directories before my runs and it doesn't seem to help.

Any ideas on how to resolve the issue for Windows? It seems really odd it is behaving differently to Linux here.

frankfliu commented 7 months ago

OnnxRuntime should be much faster than PyTorch. Are you sure you are using the GPU? Which CUDA version are you using? OnnxRuntime currently doesn't support CUDA 12.

Did you install TensorRT? Which version are you using?

david-sitsky commented 7 months ago

@frankfliu - the GPU is definitely being used as confirmed by nvidia-smi. I can also see what looks like the appropriate versions of the key libraries loaded, checked via /proc/<pid>/maps:

7d9d6992f000-7d9d6a585000 rw-p 0d52e000 103:01 35380                     /usr/lib/x86_64-linux-gnu/libnvinfer.so.8.6.1
7d9e9da00000-7d9e9eef0000 r-xp 00000000 103:01 35489                     /usr/lib/x86_64-linux-gnu/libnvinfer_plugin.so.8.6.1
7d9c36600000-7d9c3bc9a000 r-xp 00000000 103:01 35427                     /usr/lib/x86_64-linux-gnu/libcudnn_ops_infer.so.8.9.7
7d9d6992f000-7d9d6a585000 rw-p 0d52e000 103:01 35380                     /usr/lib/x86_64-linux-gnu/libnvinfer.so.8.6.1
7d9e9da00000-7d9e9eef0000 r-xp 00000000 103:01 35489                     /usr/lib/x86_64-linux-gnu/libnvinfer_plugin.so.8.6.1
7d9efd400000-7d9efd6ab000 r-xp 00000000 103:01 35443                     /usr/lib/x86_64-linux-gnu/libnvonnxparser.so.8.6.1
7da01c400000-7da01c4a0000 r-xp 00000000 103:01 309281                    /usr/local/cuda-11.8/targets/x86_64-linux/lib/libcudart.so.11.8.89
7da034400000-7da034422000 r-xp 00000000 103:01 33761                     /usr/lib/x86_64-linux-gnu/libcudnn.so.8.9.7

I did see this warning message which is a concern (I didn't see it initially as it went to stderr rather than our log files):

2024-04-22 04:24:27.743131881 [W:onnxruntime:ort-java, tensorrt_execution_provider.h:83 log] [2024-04-22 04:24:27 WARNING] onnx2trt_utils.cpp:374: Your ONNX model has been generated with INT64 weights, while TensorRT does not natively support INT64. Attempting to cast down to INT32.

I'm not sure if this could be the cause of the performance issue. https://github.com/NVIDIA/TensorRT/issues/2542 has some interesting insights here, but it is not clear to me. Are there some tweaks we have to make to the DJL ONNX exporter to handle this? As a reminder, I created this ONNX model using this command:

python3 src/main/python/model_zoo_importer.py -m intfloat/multilingual-e5-small -f OnnxRuntime

Also - as a poor man's profiler, I ran jstack against one of my worker processes many times to get a sense of the common stacks. When using PyTorch, I very rarely see code interacting with the GPU. For OnnxRuntime it is more common, but curiously I see these kinds of stacks fairly often (almost never with PyTorch):

"main" #1 prio=5 os_prio=0 cpu=568148.06ms elapsed=1044.57s tid=0x00007da2c8035020 nid=0x3ad4 runnable  [0x00007da2ce116000]
   java.lang.Thread.State: RUNNABLE
        at ai.djl.pytorch.jni.PyTorchLibrary.torchSum(PyTorchLibrary.java)
        at ai.djl.pytorch.jni.JniUtils.sum(JniUtils.java:911)
        at ai.djl.pytorch.engine.PtNDArray.sum(PtNDArray.java:997)
        at ai.djl.pytorch.engine.PtNDArray.sum(PtNDArray.java:39)
        at ai.djl.ndarray.NDArray.sum(NDArray.java:2716)

and

"main" #1 prio=5 os_prio=0 cpu=592042.44ms elapsed=1102.06s tid=0x00007da2c8035020 nid=0x3ad4 runnable  [0x00007da2ce116000]
   java.lang.Thread.State: RUNNABLE
        at ai.djl.pytorch.jni.PyTorchLibrary.torchMul(PyTorchLibrary.java)
        at ai.djl.pytorch.jni.JniUtils.mul(JniUtils.java:746)
        at ai.djl.pytorch.engine.PtNDArray.mul(PtNDArray.java:566)
        at ai.djl.pytorch.engine.PtNDArray.mul(PtNDArray.java:39)

This is curious because this code is executed when computing the "mean pool" of the model output. Effectively, it is this code (note I can switch between PyTorch and OnnxRuntime via a system property):

    protected NDArray processEmbedding(TranslatorContext ctx, NDList list)
    {
        NDArray embedding;
        if ("ortModel".equals(ctx.getModel().getNDManager().getName()))
        {
            embedding = list.get(0);
        }
        else
        {
            embedding = list.get("last_hidden_state");
        }
        Encoding encoding = (Encoding) ctx.getAttachment("Encoding");
        NDArray inputAttentionMask = ctx.getNDManager().create(encoding.getAttentionMask()).toType(DataType.FLOAT32, true);
        return meanPool(embedding, inputAttentionMask);
    }

    private static NDArray meanPool(NDArray embeddings, NDArray attentionMask)
    {
        long[] shape = embeddings.getShape().getShape();
        attentionMask = attentionMask.expandDims(-1).broadcast(shape);
        NDArray inputAttentionMaskSum = attentionMask.sum(AXIS);
        NDArray clamp = inputAttentionMaskSum.clip(1e-9f, 1e12f);
        NDArray prod = embeddings.mul(attentionMask);
        NDArray sum = prod.sum(AXIS);
        return sum.div(clamp);
    }

I am using OnnxRuntime with PyTorch as a "hybrid engine" as described by https://djl.ai/docs/hybrid_engine.html. Could this somehow be the cause of the slowdown?
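
One way to narrow down where the post-processing arrays live is to log the concrete NDArray type and device inside processEmbedding. A small debugging sketch (the TranslatorContext/NDList/Encoding arguments are the same ones used above; the class and method names here are illustrative only):

import ai.djl.huggingface.tokenizers.Encoding;
import ai.djl.ndarray.NDArray;
import ai.djl.ndarray.NDList;
import ai.djl.translate.TranslatorContext;

final class HybridEngineDebug {

    /** Prints which engine implementation and device back the model output and the mask array. */
    static void logArrayPlacement(TranslatorContext ctx, NDList list, Encoding encoding) {
        NDArray embedding = list.get(0);
        System.out.println("model output: " + embedding.getClass().getSimpleName()
                + " on " + embedding.getDevice());
        NDArray mask = ctx.getNDManager().create(encoding.getAttentionMask());
        System.out.println("attention mask: " + mask.getClass().getSimpleName()
                + " on " + mask.getDevice());
    }
}
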

frankfliu commented 7 months ago

By the way, with the latest DJL you can use this model via djl://ai.djl.huggingface.onnxruntime/intfloat/multilingual-e5-small
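
A minimal loading sketch for that URL, assuming the zoo artifact bundles its own tokenizer and text-embedding translator (the "query: " prefix follows the usual E5 convention):

import ai.djl.inference.Predictor;
import ai.djl.repository.zoo.Criteria;
import ai.djl.repository.zoo.ZooModel;

public final class E5OnnxEmbeddingExample {

    public static void main(String[] args) throws Exception {
        Criteria<String, float[]> criteria = Criteria.builder()
                .setTypes(String.class, float[].class)
                .optModelUrls("djl://ai.djl.huggingface.onnxruntime/intfloat/multilingual-e5-small")
                .optEngine("OnnxRuntime")
                .build();
        try (ZooModel<String, float[]> model = criteria.loadModel();
             Predictor<String, float[]> predictor = model.newPredictor()) {
            float[] embedding = predictor.predict("query: hello world");
            System.out.println("embedding dimensions: " + embedding.length);
        }
    }
}
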

Which GPU arch are you using?

I ran a benchmark of this model on an EC2 g4dn.8xlarge machine using djl-bench:

mkdir ~/source
cd ~/source
git clone https://github.com/deepjavalibrary/djl-serving.git

docker run -it --rm --gpus all -v ~/source:/workspace -p 8080:8080 nvidia/cuda:11.8.0-cudnn8-devel-ubuntu20.04 bash
apt-get update
apt-get install -y openjdk-17-jdk-headless
cd /workspace/djl-serving/benchmark
./gradlew run --args="-e OnnxRuntime -u djl://ai.djl.huggingface.onnxruntime/intfloat/multilingual-e5-small -s (1,7)l,(1,7)l -c 1000 -t 2"

[INFO ] - Load OnnxRuntime (1.17.1) in 0.004 ms.
[INFO ] - Running MultithreadedBenchmark on: [gpu(0)].
[INFO ] - Multithreading inference with 2 threads.
Loading:     100% |========================================|
2024-04-22 05:38:45.573039816 [W:onnxruntime:, session_state.cc:1166 VerifyEachNodeIsAssignedToAnEp] Some nodes were not assigned to the preferred execution providers which may or may not have an negative impact on performance. e.g. ORT explicitly assigns shape related ops to CPU to improve perf.
2024-04-22 05:38:45.573066226 [W:onnxruntime:, session_state.cc:1168 VerifyEachNodeIsAssignedToAnEp] Rerunning with verbose output on a non-minimal build will show node assignments.
[INFO ] - Model multilingual-e5-small loaded in: 1381.179 ms.
[INFO ] - Warmup with 2 iteration ...
[INFO ] - Warmup latency, min: 2.345 ms, max: 45.388 ms
[INFO ] - Completed 1000 requests
[INFO ] - Inference result: [0.15173343, -0.03165054, -0.1703221 ...]
[INFO ] - Throughput: 820.34, completed 1000 iteration in 1219 ms.
[INFO ] - Model loading time: 1381.179 ms.
[INFO ] - total P50: 2.392 ms, P90: 2.506 ms, P99: 2.659 ms
[INFO ] - inference P50: 2.338 ms, P90: 2.419 ms, P99: 2.510 ms
[INFO ] - preprocess P50: 0.023 ms, P90: 0.060 ms, P99: 0.138 ms
[INFO ] - postprocess P50: 0.026 ms, P90: 0.042 ms, P99: 0.076 ms

You can see the model inference latency is only 2.3 ms, and I get 820 TPS.

david-sitsky commented 7 months ago

@frankfliu - I am running on G5 instances, so I believe they are A10s. Great to hear you have added support for dynamically loading ONNX models now via HuggingFace.

I'm curious if you can run your benchmarks for a longer period of time, and compare them to PyTorch directly to see if you can reproduce my issue.

Any ideas about the "Your ONNX model has been generated with INT64 weights..." warning message I received? Could this be the cause of my performance issue? Curious that you didn't see it.

david-sitsky commented 7 months ago

I ran your benchmarks with both PyTorch and OnnxRuntime, and OnnxRuntime seems way faster. I didn't see the warning message either, so maybe I'll change my code to use the dynamic model download to see if that helps. Here are my benchmark results. I changed the -c parameter to 10000, not that I think that was needed:

root@27f843a2447e:/workspace/djl-serving/benchmark# ./gradlew run --args="-e PyTorch -u djl://ai.djl.huggingface.pytorch/intfloat/multilingual-e5-small -s (1,7)l,(1,7)l -c 10000 -t 2"

> Task :benchmark:run
[INFO ] - DJL will collect telemetry to help us better understand our users' needs, diagnose issues, and deliver additional features. If you would like to learn more or opt-out please go to: https://docs.djl.ai/docs/telemetry.html for more information.
[WARN ] - No matching cuda flavor for linux-x86_64 found: cu118.
[INFO ] - PyTorch graph executor optimizer is enabled, this may impact your inference latency and throughput. See: https://docs.djl.ai/docs/development/inference_performance_optimization.html#graph-executor-optimization
[INFO ] - Number of inter-op threads is 1
[INFO ] - Number of intra-op threads is 1
[INFO ] - Load PyTorch (2.1.2) in 0.017 ms.
[INFO ] - Running MultithreadedBenchmark on: [cpu()].
[INFO ] - Multithreading inference with 2 threads.
Loading:     100% |========================================|
[INFO ] - Model multilingual-e5-small loaded in: 463.847 ms.
[INFO ] - Warmup with 2 iteration ...
[INFO ] - Warmup latency, min: 17.516 ms, max: 176.323 ms
[INFO ] - Completed 10000 requests
[INFO ] - Inference result: [0.007107233, 0.0042331107, 0.0052915057 ...]
[INFO ] - Throughput: 116.68, completed 10000 iteration in 85705 ms.
[INFO ] - Model loading time: 463.847 ms.
[INFO ] - total P50: 17.062 ms, P90: 17.333 ms, P99: 17.606 ms
[INFO ] - inference P50: 16.994 ms, P90: 17.248 ms, P99: 17.497 ms
[INFO ] - preprocess P50: 0.037 ms, P90: 0.042 ms, P99: 0.072 ms
[INFO ] - postprocess P50: 0.031 ms, P90: 0.049 ms, P99: 0.078 ms

BUILD SUCCESSFUL in 1m 29s
9 actionable tasks: 2 executed, 7 up-to-date
root@27f843a2447e:/workspace/djl-serving/benchmark# ./gradlew run --args="-e OnnxRuntime -u djl://ai.djl.huggingface.onnxruntime/intfloat/multilingual-e5-small -s (1,7)l,(1,7)l -c 10000 -t 2"

> Task :benchmark:run
[INFO ] - DJL will collect telemetry to help us better understand our users' needs, diagnose issues, and deliver additional features. If you would like to learn more or opt-out please go to: https://docs.djl.ai/docs/telemetry.html for more information.
[WARN ] - Number of threads is less than GPU count, adjust to: 4
[INFO ] - Load OnnxRuntime (1.17.1) in 0.004 ms.
[INFO ] - Running MultithreadedBenchmark on: [gpu(0), gpu(1), gpu(2), gpu(3)].
[INFO ] - Multithreading inference with 4 threads.
Loading:     100% |========================================|
2024-04-23 10:43:01.081413654 [W:onnxruntime:, session_state.cc:1166 VerifyEachNodeIsAssignedToAnEp] Some nodes were not assigned to the preferred execution providers which may or may not have an negative impact on performance. e.g. ORT explicitly assigns shape related ops to CPU to improve perf.
2024-04-23 10:43:01.081445105 [W:onnxruntime:, session_state.cc:1168 VerifyEachNodeIsAssignedToAnEp] Rerunning with verbose output on a non-minimal build will show node assignments.
[INFO ] - Model multilingual-e5-small loaded in: 1108.532 ms.
Loading:     100% |========================================|
2024-04-23 10:43:02.235628758 [W:onnxruntime:, session_state.cc:1166 VerifyEachNodeIsAssignedToAnEp] Some nodes were not assigned to the preferred execution providers which may or may not have an negative impact on performance. e.g. ORT explicitly assigns shape related ops to CPU to improve perf.
2024-04-23 10:43:02.235655218 [W:onnxruntime:, session_state.cc:1168 VerifyEachNodeIsAssignedToAnEp] Rerunning with verbose output on a non-minimal build will show node assignments.
Loading:     100% |========================================|
2024-04-23 10:43:03.317442977 [W:onnxruntime:, session_state.cc:1166 VerifyEachNodeIsAssignedToAnEp] Some nodes were not assigned to the preferred execution providers which may or may not have an negative impact on performance. e.g. ORT explicitly assigns shape related ops to CPU to improve perf.
2024-04-23 10:43:03.317470968 [W:onnxruntime:, session_state.cc:1168 VerifyEachNodeIsAssignedToAnEp] Rerunning with verbose output on a non-minimal build will show node assignments.
Loading:     100% |========================================|
2024-04-23 10:43:04.400140448 [W:onnxruntime:, session_state.cc:1166 VerifyEachNodeIsAssignedToAnEp] Some nodes were not assigned to the preferred execution providers which may or may not have an negative impact on performance. e.g. ORT explicitly assigns shape related ops to CPU to improve perf.
2024-04-23 10:43:04.400167008 [W:onnxruntime:, session_state.cc:1168 VerifyEachNodeIsAssignedToAnEp] Rerunning with verbose output on a non-minimal build will show node assignments.
[INFO ] - Warmup with 2 iteration ...
[INFO ] - Warmup latency, min: 2.562 ms, max: 50.951 ms
[INFO ] - Completed 10000 requests
[INFO ] - Inference result: [0.15174039, -0.03164635, -0.17033705 ...]
[INFO ] - Throughput: 1855.29, completed 10000 iteration in 5390 ms.
[INFO ] - Model loading time: 1108.532 ms.
[INFO ] - total P50: 2.104 ms, P90: 2.275 ms, P99: 2.351 ms
[INFO ] - inference P50: 2.085 ms, P90: 2.259 ms, P99: 2.302 ms
[INFO ] - preprocess P50: 0.004 ms, P90: 0.012 ms, P99: 0.070 ms
[INFO ] - postprocess P50: 0.008 ms, P90: 0.020 ms, P99: 0.039 ms

BUILD SUCCESSFUL in 16s
9 actionable tasks: 2 executed, 7 up-to-date
root@27f843a2447e:/workspace/djl-serving/benchmark# 
david-sitsky commented 7 months ago

Hmmm, I can see the PyTorch benchmark didn't use the GPU for some reason...

frankfliu commented 7 months ago
No matching cuda flavor for linux-x86_64 found: cu118.

For PyTorch, you need to set PYTORCH_VERSION=1.13.1 to use cu118.

david-sitsky commented 7 months ago

Ah, thanks. Now we see better results for PyTorch, but ONNX is still quite a bit faster (3-4x):

[INFO ] - Load PyTorch (1.13.1) in 0.011 ms.
[INFO ] - Running MultithreadedBenchmark on: [gpu(0), gpu(1), gpu(2), gpu(3)].
[INFO ] - Multithreading inference with 4 threads.
Downloading: 100% |========================================|
Loading:     100% |========================================|
[INFO ] - Model multilingual-e5-small loaded in: 6627.692 ms.
Loading:     100% |========================================|
Loading:     100% |========================================|
Loading:     100% |========================================|
[INFO ] - Warmup with 2 iteration ...
[INFO ] - Warmup latency, min: 805.236 ms, max: 1657.596 ms
[INFO ] - Completed 10000 requests
[INFO ] - Inference result: [0.0071072537, 0.004233095, 0.005291494 ...]
[INFO ] - Throughput: 565.10, completed 10000 iteration in 17696 ms.
[INFO ] - Model loading time: 6627.692 ms.
[INFO ] - total P50: 7.055 ms, P90: 7.755 ms, P99: 8.330 ms
[INFO ] - inference P50: 6.974 ms, P90: 7.675 ms, P99: 8.247 ms
[INFO ] - preprocess P50: 0.036 ms, P90: 0.040 ms, P99: 0.057 ms
[INFO ] - postprocess P50: 0.043 ms, P90: 0.049 ms, P99: 0.071 ms
david-sitsky commented 7 months ago

@frankfliu - for my program, despite using the downloaded version of the ONNX models, performance is sadly the same. Overall my program (which does a lot more than just inferencing) is 3x slower compared to using the PyTorch engine.

Returning to the jstacks, I believe the time difference is due to post-processing in the Translator after inference has happened. I think the issue is that ONNX has to allocate new PyTorch NDArrays and copy the data before it can run the operations, due to the "hybrid engine" approach via NDArrayAdapter. Whereas with PyTorch, it can very quickly compute split(), sum() and prod() on the existing NDArray by calling JniUtils.sum() and JniUtils.mul() immediately.

So I don't think ONNX inference is the problem; it is the post-processing code, and the need to copy/allocate PyTorch arrays to do that work, that seems to be very slow.

As a reminder, here are some example jstacks when using ONNX which are not seen with PyTorch, as these post-processing sections run very quickly.

Any ideas on what can be done to speed up ONNX Translator post-processing?


"main" #1 prio=5 os_prio=0 cpu=169480.70ms elapsed=452.04s tid=0x0000713158035020 nid=0x1b99 runnable  [0x000071315ef16000]
   java.lang.Thread.State: RUNNABLE
        at jdk.internal.misc.Unsafe.setMemory0(Unsafe.java)
        at jdk.internal.misc.Unsafe.setMemory(Unsafe.java:742)
        at jdk.internal.misc.Unsafe.setMemory(Unsafe.java:753)
        at java.nio.DirectByteBuffer.<init>(DirectByteBuffer.java:130)
        at java.nio.ByteBuffer.allocateDirect(ByteBuffer.java:332)
        at ai.djl.pytorch.engine.PtNDManager.allocateDirect(PtNDManager.java:46)
        at ai.djl.pytorch.engine.PtNDManager.create(PtNDManager.java:75)
        at ai.djl.pytorch.engine.PtNDManager.from(PtNDManager.java:55)
        at ai.djl.pytorch.engine.PtNDManager.from(PtNDManager.java:31)
        at ai.djl.ndarray.NDArrayAdapter.getAlternativeArray(NDArrayAdapter.java:1315)
        at ai.djl.ndarray.NDArrayAdapter.split(NDArrayAdapter.java:876)
        at ai.djl.ndarray.NDArray.split(NDArray.java:3174)
        at ai.djl.translate.StackBatchifier.unbatchify(StackBatchifier.java:118)

"main" #1 prio=5 os_prio=0 cpu=159754.36ms elapsed=433.20s tid=0x0000713158035020 nid=0x1b99 runnable  [0x000071315ef16000]
   java.lang.Thread.State: RUNNABLE
        at ai.djl.pytorch.jni.PyTorchLibrary.torchSum(PyTorchLibrary.java)
        at ai.djl.pytorch.jni.JniUtils.sum(JniUtils.java:911)
        at ai.djl.pytorch.engine.PtNDArray.sum(PtNDArray.java:997)
        at ai.djl.pytorch.engine.PtNDArray.sum(PtNDArray.java:39)
        at ai.djl.ndarray.NDArray.sum(NDArray.java:2716)
        ...Translator.meanPool()

"main" #1 prio=5 os_prio=0 cpu=175694.74ms elapsed=468.02s tid=0x0000713158035020 nid=0x1b99 runnable  [0x000071315ef16000]
   java.lang.Thread.State: RUNNABLE
        at ai.djl.pytorch.jni.PyTorchLibrary.torchMul(PyTorchLibrary.java)
        at ai.djl.pytorch.jni.JniUtils.mul(JniUtils.java:746)
        at ai.djl.pytorch.engine.PtNDArray.mul(PtNDArray.java:566)
        at ai.djl.pytorch.engine.PtNDArray.mul(PtNDArray.java:39)
        ...Translator.meanPool()
...
frankfliu commented 7 months ago

@david-sitsky

You can use the Metrics class to collect pre/post-processing time versus inference time. It will give a clearer picture of where the bottleneck is.

If post-processing is the bottleneck, you can make the pooling and normalization part of the model and convert it to ONNX with the post-processing included.

It's possible to avoid the memory copy for the hybrid engine. There is a private method in OnnxRuntime that we could invoke via reflection (but I would rather avoid that if possible).
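
A minimal sketch of collecting those metrics with a Predictor (the metric names "Preprocess", "Inference" and "Postprocess" are assumed from DJL's Predictor instrumentation; the predictor and tokenised inputs are taken as given):

import java.util.List;

import ai.djl.huggingface.tokenizers.Encoding;
import ai.djl.inference.Predictor;
import ai.djl.metric.Metrics;
import ai.djl.translate.TranslateException;

final class PredictorMetricsExample {

    /** Runs predictions while collecting DJL's built-in timing metrics, then prints the P50s. */
    static void timedPredict(Predictor<Encoding, float[]> predictor, List<Encoding> inputs)
            throws TranslateException {
        Metrics metrics = new Metrics();
        predictor.setMetrics(metrics);
        for (Encoding input : inputs) {
            predictor.predict(input);
        }
        // Metric names assumed from Predictor's built-in instrumentation; values are in nanoseconds.
        System.out.println("preprocess  P50: " + metrics.percentile("Preprocess", 50));
        System.out.println("inference   P50: " + metrics.percentile("Inference", 50));
        System.out.println("postprocess P50: " + metrics.percentile("Postprocess", 50));
    }
}
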

david-sitsky commented 7 months ago

@frankfliu - here are the results of the Metrics run. As suspected, it looks like post-processing is the culprit for OnnxRuntime. Ideally I'd like to keep the same translator post-processing code regardless of the engine I use.

Any ideas on how we can speed up OnnxRuntime post-processing? I understand it is not ideal to use the private method, but the overheads seem very large at the moment. Thanks for all your help with this.

PyTorch:

Metrics for intfloat/multilingual-e5-small
total P50: 9.506 ms, P90: 24.511 ms, P99: 329.438 ms
inference P50: 4.243 ms, P90: 5.519 ms, P99: 323.447 ms
preprocess P50: 0.331 ms, P90: 1.999 ms, P99: 4.538 ms
postprocess P50: 1.831 ms, P90: 5.165 ms, P99: 9.329 ms

OnnxRuntime:

Metrics for intfloat/multilingual-e5-small
total P50: 84.301 ms, P90: 338.719 ms, P99: 609.099 ms
inference P50: 4.396 ms, P90: 9.542 ms, P99: 22.122 ms
preprocess P50: 0.176 ms, P90: 0.254 ms, P99: 2.514 ms
postprocess P50: 68.294 ms, P90: 103.183 ms, P99: 121.825 ms
frankfliu commented 7 months ago

For OnnxRuntime, here is what might be happening:

  1. Post-processing is using PyTorch on the CPU, and CPU operations are much slower than GPU ones.
  2. There are unnecessary memory copies: GPU -> CPU -> Java heap -> PyTorch -> Java heap.
david-sitsky commented 7 months ago

Thanks @frankfliu. Given what my jstacks show (albeit there are caveats with that), my guess is PyTorch CPU is being used in post-processing. Any suggestions on how to fix that?

frankfliu commented 7 months ago

@david-sitsky

It seems some testing code got merged into master; please try again after this PR is merged: https://github.com/deepjavalibrary/djl/pull/3122

david-sitsky commented 7 months ago

@frankfliu - that change of yours looks good, but I am not using TextEmbeddingTranslator; I use my own translator, which takes tokenised input directly (partly to handle large documents and to ensure the number of tokens passed doesn't exceed 512). So I don't think that will solve the issue, sadly:

/**
 * An embeddings translator which takes encoded input rather than
 * the raw string.
 */
class EncodingEmbeddingTranslator implements Translator<Encoding, float[]>
{
    private static final int[] AXIS = {0};

    @Override
    @CanIgnoreReturnValue
    public NDList processInput(TranslatorContext ctx, Encoding encoding)
    {
        ctx.setAttachment("Encoding", encoding);
        return encoding.toNDList(ctx.getNDManager(), false);
    }

    @Override
    public NDList batchProcessInput(TranslatorContext ctx, List<Encoding> encodings)
    {
        NDManager manager = ctx.getNDManager();
        ctx.setAttachment("Encodings", encodings);
        NDList[] batch = new NDList[encodings.size()];
        for (int i = 0; i < encodings.size(); i++)
        {
            batch[i] = encodings.get(i).toNDList(manager, false);
        }
        return getBatchifier().batchify(batch);
    }

    @Override
    public float[] processOutput(TranslatorContext ctx, NDList list) {
        Encoding encoding = (Encoding) ctx.getAttachment("Encoding");
        NDArray embeddings = processEmbedding(ctx, list, encoding);
        embeddings = embeddings.normalize(2, 0);
        return embeddings.toFloatArray();
    }

    @Override
    public List<float[]> batchProcessOutput(TranslatorContext ctx, NDList list) {
        NDList[] batch = getBatchifier().unbatchify(list);
        List<Encoding> encodings = (List<Encoding>) ctx.getAttachment("Encodings");
        List<float[]> ret = new ArrayList<>(batch.length);
        for (int i = 0; i < batch.length; ++i) {
            NDArray array = processEmbedding(ctx, batch[i], encodings.get(i));
            array = array.normalize(2, 0);
            ret.add(array.toFloatArray());
        }
        return ret;
    }

    /**
     * Process the embeddings.
     *
     * @param ctx the translator context.
     * @param list the embeddings.
     * @param encoding the encoding.
     * @return the updated embeddings.
     */
    protected NDArray processEmbedding(TranslatorContext ctx, NDList list, Encoding encoding)
    {
        NDArray embedding;
        if ("ortModel".equals(ctx.getModel().getNDManager().getName()))
        {
            embedding = list.get(0);
        }
        else
        {
            embedding = list.get("last_hidden_state");
        }
        NDArray inputAttentionMask = ctx.getNDManager().create(encoding.getAttentionMask()).toType(DataType.FLOAT32, true);
        return meanPool(embedding, inputAttentionMask);
    }

    /**
     * Computes the mean pool.
     *
     * @param embeddings the embeddings.
     * @param attentionMask the attention mask.
     * @return the mean pool.
     */
    private static NDArray meanPool(NDArray embeddings, NDArray attentionMask)
    {
        long[] shape = embeddings.getShape().getShape();
        attentionMask = attentionMask.expandDims(-1).broadcast(shape);
        NDArray inputAttentionMaskSum = attentionMask.sum(AXIS);
        NDArray clamp = inputAttentionMaskSum.clip(1e-9f, 1e12f);
        NDArray prod = embeddings.mul(attentionMask);
        NDArray sum = prod.sum(AXIS);
        return sum.div(clamp);
    }
}
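
For completeness, here is a hypothetical way this translator could be wired up against the ONNX model discussed earlier (the model URL and tokenizer name are assumptions; DJL's HuggingFaceTokenizer produces the Encoding inputs):

import ai.djl.huggingface.tokenizers.Encoding;
import ai.djl.huggingface.tokenizers.HuggingFaceTokenizer;
import ai.djl.inference.Predictor;
import ai.djl.repository.zoo.Criteria;
import ai.djl.repository.zoo.ZooModel;

public final class EncodingEmbeddingExample {

    public static void main(String[] args) throws Exception {
        Criteria<Encoding, float[]> criteria = Criteria.builder()
                .setTypes(Encoding.class, float[].class)
                .optModelUrls("djl://ai.djl.huggingface.onnxruntime/intfloat/multilingual-e5-small")
                .optEngine("OnnxRuntime")
                .optTranslator(new EncodingEmbeddingTranslator())
                .build();
        try (HuggingFaceTokenizer tokenizer = HuggingFaceTokenizer.newInstance("intfloat/multilingual-e5-small");
             ZooModel<Encoding, float[]> model = criteria.loadModel();
             Predictor<Encoding, float[]> predictor = model.newPredictor()) {
            Encoding encoding = tokenizer.encode("query: hello world");
            float[] embedding = predictor.predict(encoding);
            System.out.println("embedding dimensions: " + embedding.length);
        }
    }
}
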
david-sitsky commented 6 months ago

@frankfliu - I still believe the issue is that the alternative manager for OrtEngine (in my case PtEngine) is not using the GPU, hence the post-processing operations are slow. I had a quick look in this area and noticed this code:

    protected BaseNDManager(NDManager parent, Device device) {
        this.parent = parent;
        this.device = device == null ? defaultDevice() : device;
        resources = new ConcurrentHashMap<>();
        tempResources = new ConcurrentHashMap<>();
        uid = UUID.randomUUID().toString();
        Engine engine = getEngine().getAlternativeEngine();
        if (engine != null) {
            alternativeManager = engine.newBaseManager(Device.cpu());
        }
    }

Won't this force the alternative manager (PtEngine) to use the CPU, making post-processing slow? Shouldn't this code just be alternativeManager = engine.newBaseManager(device);?

frankfliu commented 6 months ago

@david-sitsky

You are right, this is an issue. The problem is that the alternative engine may not support the GPU; I think it should be:

alternativeManager = engine.newBaseManager();
david-sitsky commented 6 months ago

@frankfliu - In the ideal case, where the alternative engine does support GPUs, wouldn't we want it to use the same GPU, so any downstream operations will avoid any potential copies? Is there a way we can catch an exception to handle those engines which don't support GPUs and then just use the version without arguments?

In my case, I will sometimes run multiple processes, each dedicated to a specific GPU, and I'd ideally want the operations for each process to be pinned to the right GPU.

david-sitsky commented 6 months ago

I'll create a PR with the appropriate exception handling so you can look at it. Indeed without it some tests will fail, although this is more likely a setup issue on my box.

java.lang.UnsatisfiedLinkError: 'boolean ai.djl.engine.rust.RustLibrary.isCudaAvailable()'
    at ai.djl.engine.rust.RustLibrary.isCudaAvailable(Native Method)
    at ai.djl.engine.rust.RsEngine.hasCapability(RsEngine.java:73)
    at ai.djl.engine.Engine.defaultDevice(Engine.java:215)
    at ai.djl.ndarray.BaseNDManager.defaultDevice(BaseNDManager.java:72)
    at ai.djl.ndarray.BaseNDManager.<init>(BaseNDManager.java:59)
    at ai.djl.engine.rust.RsNDManager.<init>(RsNDManager.java:34)
    at ai.djl.engine.rust.RsNDManager.<init>(RsNDManager.java:29)
    at ai.djl.engine.rust.RsNDManager$SystemManager.<init>(RsNDManager.java:263)
    at ai.djl.engine.rust.RsNDManager.<clinit>(RsNDManager.java:31)
    at ai.djl.engine.rust.RsEngine.newBaseManager(RsEngine.java:87)
    at ai.djl.ndarray.NDManager.newBaseManager(NDManager.java:140)
    at ai.djl.engine.rust.NDArrayTests.testComparisonOp(NDArrayTests.java:121)
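
For illustration, a rough sketch of the device-preserving fallback being discussed, mirroring the BaseNDManager constructor quoted earlier (this is not necessarily what the eventual PR implements):

    protected BaseNDManager(NDManager parent, Device device) {
        this.parent = parent;
        this.device = device == null ? defaultDevice() : device;
        resources = new ConcurrentHashMap<>();
        tempResources = new ConcurrentHashMap<>();
        uid = UUID.randomUUID().toString();
        Engine engine = getEngine().getAlternativeEngine();
        if (engine != null) {
            try {
                // Prefer the caller's device so post-processing stays on the same GPU
                // and avoids GPU -> CPU copies.
                alternativeManager = engine.newBaseManager(device);
            } catch (UnsatisfiedLinkError | EngineException e) {
                // The alternative engine cannot handle this device (e.g. no CUDA support);
                // fall back to its default device.
                alternativeManager = engine.newBaseManager();
            }
        }
    }
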
david-sitsky commented 6 months ago

My change now makes OnnxRuntime run in almost the same time as PyTorch. I've created a PR here: https://github.com/deepjavalibrary/djl/pull/3138.