deepjavalibrary / djl

An Engine-Agnostic Deep Learning Framework in Java
https://djl.ai
Apache License 2.0

Error when running inference with a PyTorch encoder model from the PyTorch seq2seq_translation_tutorial #706

Closed: cuongducle closed this issue 3 years ago

cuongducle commented 3 years ago

Description

I trained an encoder model in PyTorch and exported it with torch.jit.trace. Inference with the traced model works correctly in Python on the same computer, but when I load the model in Java and run inference, it fails.

import torch
import torch.nn as nn


class EncoderRNN(nn.Module):
    def __init__(self, hidden_size, embedding, n_layers=1, dropout=0):
        super(EncoderRNN, self).__init__()
        self.n_layers = n_layers
        self.hidden_size = hidden_size
        self.embedding = embedding
        # Bidirectional GRU; dropout is only applied between stacked layers
        self.gru = nn.GRU(hidden_size, hidden_size, n_layers,
                          dropout=(0 if n_layers == 1 else dropout), bidirectional=True)

    def forward(self, input_seq, input_lengths, hidden=None):
        # (batch, seq_len) -> (seq_len, batch), the layout nn.GRU expects by default
        input_seq = input_seq.transpose(0, 1).long()
        input_lengths = input_lengths.long()
        embedded = self.embedding(input_seq)  # (seq_len, batch, hidden_size)
        packed = torch.nn.utils.rnn.pack_padded_sequence(embedded, input_lengths)
        outputs, hidden = self.gru(packed, hidden)
        outputs, _ = torch.nn.utils.rnn.pad_packed_sequence(outputs)
        # Sum the forward and backward directions
        outputs = outputs[:, :, :self.hidden_size] + outputs[:, :, self.hidden_size:]
        return outputs, hidden

# `encoder` is the trained EncoderRNN instance (its construction is not shown here)
encoder.eval()  # disable dropout before tracing
test_seq = torch.IntTensor(1, 500).random_(0, 143)  # example input: (batch=1, seq_len=500)
# The lengths tensor holds the sequence length (dim 1), not the batch size (dim 0)
test_seq_length = torch.LongTensor([test_seq.size()[1]])
traced_encoder = torch.jit.trace(encoder, (test_seq, test_seq_length))

device = torch.device("cpu")
traced_encoder.to(device)
traced_encoder.eval()
traced_encoder.save("pt_trace/scripted_chatbot.pt")

Expected Behavior

I expected inference in DJL to work the same way it does in Python. This is the DJL inference code:

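import ai.djl.inference.Predictor;
import ai.djl.ndarray.NDArray;
import ai.djl.ndarray.NDList;
import ai.djl.ndarray.NDManager;
import ai.djl.ndarray.types.Shape;
import ai.djl.repository.zoo.Criteria;
import ai.djl.repository.zoo.ModelZoo;
import ai.djl.repository.zoo.ZooModel;
import ai.djl.training.util.ProgressBar;
import ai.djl.translate.Batchifier;
import ai.djl.translate.NoopTranslator;
import java.text.Normalizer;
import java.util.ArrayList;

// Encode the NFD-normalized input one character at a time with the Word2Index vocabulary
String inputString = "cô giáo tôi là ~ ths #";
String normalizedString = Normalizer.normalize(inputString, Normalizer.Form.NFD);
ArrayList<Integer> result = new ArrayList<>();
for (char ch : normalizedString.toCharArray()) {
    String key = String.valueOf(ch);
    Object index = Word2Index.get(key);
    result.add(Integer.parseInt(index.toString()));
}
result.add(2); // end-of-sequence index

// Use a single manager for all arrays instead of creating a new one per array
NDManager manager = NDManager.newBaseManager();
int[] tokens = result.stream().mapToInt(Integer::intValue).toArray();
// Token array of shape (seq_len, 1)
NDArray ndArray = manager.create(tokens, new Shape(result.size(), 1));
// Sequence length (27 for this input), computed instead of hard-coded
NDArray lengths = manager.create(new int[] {result.size()});

public NDList processInput(NDArray input_batch, NDArray input_lengths) {
    return new NDList(input_batch, input_lengths);
}

Criteria<NDList, NDList> criteria = Criteria.builder()
        .setTypes(NDList.class, NDList.class)
        .optTranslator(new NoopTranslator(Batchifier.STACK))
        .optModelUrls("file:///home/hironeo/Desktop/expand_abbr/pt_trace")
        .optModelName("scripted_chatbot")
        .optProgress(new ProgressBar()).build();
ZooModel<NDList, NDList> model = ModelZoo.loadModel(criteria);
Predictor<NDList, NDList> predictor = model.newPredictor();
NDList a = predictor.predict(processInput(ndArray, lengths.get(0)));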
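The character loop can be factored into a small helper that also makes the sequence length explicit, so the value for the lengths tensor is simply the length of the returned array. This is a minimal stdlib-only sketch, assuming the vocabulary is a Map<String, Integer> (the issue does not show the actual type of Word2Index) and that index 2 is the end-of-sequence token, as the result.add(2) call suggests. The class and method names are hypothetical.

```java
import java.text.Normalizer;
import java.util.Map;

/** Hypothetical helper: turns a string into the character indices the encoder consumes. */
final class CharEncoder {
    /** Assumed end-of-sequence index, taken from the result.add(2) call in the issue. */
    static final int EOS_INDEX = 2;

    private CharEncoder() {}

    /**
     * NFD-normalizes the text, maps each char through the vocabulary and appends EOS.
     * Fails fast on an unknown character instead of throwing a NullPointerException later.
     */
    static int[] encode(String text, Map<String, Integer> word2Index) {
        String normalized = Normalizer.normalize(text, Normalizer.Form.NFD);
        int[] indices = new int[normalized.length() + 1];
        for (int i = 0; i < normalized.length(); i++) {
            String key = String.valueOf(normalized.charAt(i));
            Integer index = word2Index.get(key);
            if (index == null) {
                throw new IllegalArgumentException("Character not in vocabulary: " + key);
            }
            indices[i] = index;
        }
        indices[normalized.length()] = EOS_INDEX; // end-of-sequence marker
        return indices;
    }
}
```

For the sentence in the issue, NFD splits ô, á and à into a base letter plus a combining mark, giving 26 characters plus EOS, which matches the length of 27 used above.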

Error Message

---------------------------------------------------------------------------
ai.djl.translate.TranslateException: ai.djl.engine.EngineException: The following operation failed in the TorchScript interpreter.
Traceback of TorchScript, serialized code (most recent call last):
  File "code/__torch__.py", line 17, in forward
    lengths0 = torch.to(lengths, torch.device("cpu"), 4, False, False, None)
    input0, batch_sizes = torch._pack_padded_sequence(_2, lengths0, False)
    _3, _4, = (_0).forward(batch_sizes, input0, )
               ~~~~~~~~~~~ <--- HERE
    max_seq_length = ops.prim.NumToTensor(torch.size(batch_sizes, 0))
    outputs, _5 = torch._pad_packed_sequence(_3, batch_sizes, False, 0., int(max_seq_length))
  File "code/__torch__/torch/nn/modules/rnn.py", line 42, in forward
    hx = torch.zeros([4, 1, 500], dtype=6, layout=None, device=torch.device("cpu"), pin_memory=False)
    _16 = [_15, _14, _13, _12, _11, _10, _9, _8, _7, _6, _5, _4, _3, _2, _1, _0]
    _17, _18 = torch.gru(input, batch_sizes, hx, _16, True, 2, 0.10000000000000001, False, True)
               ~~~~~~~~~ <--- HERE
    return (_17, _18)

Traceback of TorchScript, original code (most recent call last):
/home/hironeo/anaconda3/envs/bert/lib/python3.6/site-packages/torch/nn/modules/rnn.py(743): forward
/home/hironeo/anaconda3/envs/bert/lib/python3.6/site-packages/torch/nn/modules/module.py(709): _slow_forward
/home/hironeo/anaconda3/envs/bert/lib/python3.6/site-packages/torch/nn/modules/module.py(725): _call_impl
trace.py(36): forward
/home/hironeo/anaconda3/envs/bert/lib/python3.6/site-packages/torch/nn/modules/module.py(709): _slow_forward
/home/hironeo/anaconda3/envs/bert/lib/python3.6/site-packages/torch/nn/modules/module.py(725): _call_impl
/home/hironeo/anaconda3/envs/bert/lib/python3.6/site-packages/torch/jit/_trace.py(940): trace_module
/home/hironeo/anaconda3/envs/bert/lib/python3.6/site-packages/torch/jit/_trace.py(742): trace
trace.py(226): <module>
RuntimeError: The size of tensor a (500) must match the size of tensor b (1500) at non-singleton dimension 2

    at ai.djl.inference.Predictor.batchPredict(Predictor.java:180)
    at ai.djl.inference.Predictor.predict(Predictor.java:128)
    at .(#341:1)
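One plausible reading of this error (the thread only confirms that the input shape was wrong): NoopTranslator(Batchifier.STACK) stacks the single input into a batch, adding a leading axis, so the (27, 1) token array reaches forward as (1, 27, 1). After transpose(0, 1) and the embedding lookup this becomes a 4-D tensor, while the GRU was traced with 3-D (seq_len, batch, hidden) input. The hidden state in the trace, [4, 1, 500], corresponds to 2 layers × 2 directions with hidden size 500, and PyTorch's GRU stacks its three gates into weight matrices with 3 × 500 = 1500 rows, which is likely where tensor b (1500) comes from. The sketch below only replays the shape arithmetic in plain Java; the class is hypothetical and does not call DJL or PyTorch.

```java
import java.util.Arrays;

/** Hypothetical shape walk-through of EncoderRNN.forward; no tensors are created. */
final class EncoderShapes {
    private EncoderShapes() {}

    /** Batchifier.STACK applied to a single input: prepend a batch axis of size 1. */
    static long[] stack(long[] shape) {
        long[] out = new long[shape.length + 1];
        out[0] = 1;
        System.arraycopy(shape, 0, out, 1, shape.length);
        return out;
    }

    /** input_seq.transpose(0, 1): swap the first two axes. */
    static long[] transpose01(long[] shape) {
        long[] out = shape.clone();
        out[0] = shape[1];
        out[1] = shape[0];
        return out;
    }

    /** nn.Embedding appends an axis of size embeddingDim. */
    static long[] embed(long[] shape, long embeddingDim) {
        long[] out = Arrays.copyOf(shape, shape.length + 1);
        out[shape.length] = embeddingDim;
        return out;
    }

    /** Shape of `embedded` inside forward for a given input shape passed to predict. */
    static long[] embeddedShape(long[] inputShape, boolean stackBatchifier, long hidden) {
        long[] shape = stackBatchifier ? stack(inputShape) : inputShape;
        return embed(transpose01(shape), hidden);
    }
}
```

With the reported (27, 1) input and STACK, the result is [27, 1, 1, 500]. A 1-D (27) array with STACK, or a (1, 27) array passed without a batchifier, both yield the 3-D [27, 1, 500] layout the trace was recorded with.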

Steps to reproduce


1. Export the traced model.
2. Run the Java code.

Environment Info

I am using a Java Jupyter notebook (IJava) with these dependencies:


%maven ai.djl:api:0.9.0
%maven ai.djl.pytorch:pytorch-engine:0.9.0
%maven ai.djl.pytorch:pytorch-native-auto:1.7.1

The output of ./gradlew debugEnv:

Starting a Gradle Daemon (subsequent builds will be faster)

> Task :integration:debugEnv
[DEBUG] - cudart library not found.
[DEBUG] - Using cache dir: /home/hironeo/.djl.ai/mxnet
[DEBUG] - Loading mxnet library from: /home/hironeo/.djl.ai/mxnet/1.7.0-backport-mkl-linux-x86_64/libmxnet.so
[DEBUG] - Engine loaded from provider: MXNet
[DEBUG] - Found default engine: MXNet
----------- System Properties -----------
sun.cpu.isalist: 
sun.desktop: gnome
sun.io.unicode.encoding: UnicodeLittle
sun.cpu.endian: little
java.vendor.url.bug: http://bugreport.sun.com/bugreport/
file.separator: /
java.vendor: Private Build
sun.boot.class.path: /usr/lib/jvm/java-8-openjdk-amd64/jre/lib/resources.jar:/usr/lib/jvm/java-8-openjdk-amd64/jre/lib/rt.jar:/usr/lib/jvm/java-8-openjdk-amd64/jre/lib/sunrsasign.jar:/usr/lib/jvm/java-8-openjdk-amd64/jre/lib/jsse.jar:/usr/lib/jvm/java-8-openjdk-amd64/jre/lib/jce.jar:/usr/lib/jvm/java-8-openjdk-amd64/jre/lib/charsets.jar:/usr/lib/jvm/java-8-openjdk-amd64/jre/lib/jfr.jar:/usr/lib/jvm/java-8-openjdk-amd64/jre/classes
java.ext.dirs: /usr/lib/jvm/java-8-openjdk-amd64/jre/lib/ext:/usr/java/packages/lib/ext
java.version: 1.8.0_282
java.vm.info: mixed mode
awt.toolkit: sun.awt.X11.XToolkit
user.language: en
java.specification.vendor: Oracle Corporation
sun.java.command: ai.djl.integration.util.DebugEnvironment
java.home: /usr/lib/jvm/java-8-openjdk-amd64/jre
sun.arch.data.model: 64
java.vm.specification.version: 1.8
java.class.path: /home/hironeo/Desktop/djl/integration/build/classes/java/main:/home/hironeo/Desktop/djl/integration/build/resources/main:/home/hironeo/.gradle/caches/modules-2/files-2.1/commons-cli/commons-cli/1.4/c51c00206bb913cd8612b24abd9fa98ae89719b1/commons-cli-1.4.jar:/home/hironeo/.gradle/caches/modules-2/files-2.1/org.apache.logging.log4j/log4j-slf4j-impl/2.13.3/7cca27a921a18645139cf651c04b83b1a19cfd76/log4j-slf4j-impl-2.13.3.jar:/home/hironeo/Desktop/djl/basicdataset/build/libs/basicdataset-0.10.0-SNAPSHOT.jar:/home/hironeo/Desktop/djl/model-zoo/build/libs/model-zoo-0.10.0-SNAPSHOT.jar:/home/hironeo/Desktop/djl/testing/build/libs/testing-0.10.0-SNAPSHOT.jar:/home/hironeo/.gradle/caches/modules-2/files-2.1/org.testng/testng/7.1.0/b0bcea778fb2899aeb4014c558babea8833d180a/testng-7.1.0.jar:/home/hironeo/Desktop/djl/mxnet/mxnet-model-zoo/build/libs/mxnet-model-zoo-0.10.0-SNAPSHOT.jar:/home/hironeo/Desktop/djl/mxnet/mxnet-engine/build/libs/mxnet-engine-0.10.0-SNAPSHOT.jar:/home/hironeo/.gradle/caches/modules-2/files-2.1/ai.djl.mxnet/mxnet-native-auto/1.7.0-backport/ee5b368ef94c1fcec4ade4a6edacffb420fefce7/mxnet-native-auto-1.7.0-backport.jar:/home/hironeo/Desktop/djl/api/build/libs/api-0.10.0-SNAPSHOT.jar:/home/hironeo/.gradle/caches/modules-2/files-2.1/org.slf4j/slf4j-api/1.7.30/b5a4b6d16ab13e34a88fae84c35cd5d68cac922c/slf4j-api-1.7.30.jar:/home/hironeo/.gradle/caches/modules-2/files-2.1/org.apache.logging.log4j/log4j-core/2.13.3/4e857439fc4fe974d212adaaaa3b118b8b50e3ec/log4j-core-2.13.3.jar:/home/hironeo/.gradle/caches/modules-2/files-2.1/org.apache.logging.log4j/log4j-api/2.13.3/ec1508160b93d274b1add34419b897bae84c6ca9/log4j-api-2.13.3.jar:/home/hironeo/.gradle/caches/modules-2/files-2.1/org.apache.commons/commons-csv/1.8/37ca9a9aa2d4be2599e55506a6d3170dd7a3df4/commons-csv-1.8.jar:/home/hironeo/.gradle/caches/modules-2/files-2.1/com.beust/jcommander/1.72/6375e521c1e11d6563d4f25a07ce124ccf8cd171/jcommander-1.72.jar:/home/hironeo/.gradle/caches/modules-2/files-2.1/com.google.inject/guice/4.1.0/faf9ee8ac09eafd1128091426dd367a8c0085d55/guice-4.1.0-no_aop.jar:/home/hironeo/.gradle/caches/modules-2/files-2.1/org.yaml/snakeyaml/1.21/18775fdda48574784f40b47bf478ab0593f92e4d/snakeyaml-1.21.jar:/home/hironeo/.gradle/caches/modules-2/files-2.1/com.google.code.gson/gson/2.8.6/9180733b7df8542621dc12e21e87557e8c99b8cb/gson-2.8.6.jar:/home/hironeo/.gradle/caches/modules-2/files-2.1/net.java.dev.jna/jna/5.3.0/4654d1da02e4173ba7b64f7166378847db55448a/jna-5.3.0.jar:/home/hironeo/.gradle/caches/modules-2/files-2.1/org.apache.commons/commons-compress/1.20/b8df472b31e1f17c232d2ad78ceb1c84e00c641b/commons-compress-1.20.jar:/home/hironeo/.gradle/caches/modules-2/files-2.1/javax.inject/javax.inject/1/6975da39a7040257bd51d21a231b76c915872d38/javax.inject-1.jar:/home/hironeo/.gradle/caches/modules-2/files-2.1/aopalliance/aopalliance/1.0/235ba8b489512805ac13a8f9ea77a1ca5ebe3e8/aopalliance-1.0.jar:/home/hironeo/.gradle/caches/modules-2/files-2.1/com.google.guava/guava/19.0/6ce200f6b23222af3d8abb6b6459e6c44f4bb0e9/guava-19.0.jar
user.name: hironeo
ai.djl.logging.level: debug
file.encoding: UTF-8
java.specification.version: 1.8
java.awt.printerjob: sun.print.PSPrinterJob
user.timezone: Asia/Ho_Chi_Minh
user.home: /home/hironeo
library.jansi.path: /home/hironeo/.gradle/native/jansi/1.18/linux64
os.version: 5.8.0-43-generic
sun.management.compiler: HotSpot 64-Bit Tiered Compilers
java.specification.name: Java Platform API Specification
java.class.version: 52.0
org.gradle.internal.http.connectionTimeout: 60000
java.library.path: :/home/hironeo/hadoop/lib/native:/usr/java/packages/lib/amd64:/usr/lib/x86_64-linux-gnu/jni:/lib/x86_64-linux-gnu:/usr/lib/x86_64-linux-gnu:/usr/lib/jni:/lib:/usr/lib
jnidispatch.path: /home/hironeo/.cache/JNA/temp/jna2403741189233122078.tmp
org.gradle.internal.publish.checksums.insecure: true
sun.jnu.encoding: UTF-8
os.name: Linux
user.variant: 
java.vm.specification.vendor: Oracle Corporation
org.gradle.appname: gradlew
java.io.tmpdir: /tmp
line.separator: 

java.endorsed.dirs: /usr/lib/jvm/java-8-openjdk-amd64/jre/lib/endorsed
os.arch: amd64
java.awt.graphicsenv: sun.awt.X11GraphicsEnvironment
java.runtime.version: 1.8.0_282-8u282-b08-0ubuntu1~20.04-b08
java.vm.specification.name: Java Virtual Machine Specification
user.dir: /home/hironeo/Desktop/djl/integration
org.gradle.internal.http.socketTimeout: 120000
user.country: US
sun.java.launcher: SUN_STANDARD
sun.os.patch.level: unknown
jna.loaded: true
java.vm.name: OpenJDK 64-Bit Server VM
file.encoding.pkg: sun.io
path.separator: :
java.vm.vendor: Private Build
java.vendor.url: http://java.oracle.com/
sun.boot.library.path: /usr/lib/jvm/java-8-openjdk-amd64/jre/lib/amd64:/usr/lib/jvm/java-8-openjdk-amd64/jre/lib/amd64
java.vm.version: 25.282-b08
jna.platform.library.path: /usr/lib/x86_64-linux-gnu:/lib/x86_64-linux-gnu:/usr/lib64:/lib64:/usr/lib:/lib:/usr/lib/x86_64-linux-gnu/libfakeroot
java.runtime.name: OpenJDK Runtime Environment

--------- Environment Variables ---------
PATH: /home/hironeo/anaconda3/envs/bert/bin:/home/hironeo/anaconda3/condabin:/usr/local/sbin:/usr/local/bin:/usr/sbin:/usr/bin:/sbin:/bin:/usr/games:/usr/local/games:/snap/bin:/home/hironeo/hadoop/sbin:/home/hironeo/hadoop/bin:/home/hironeo/Android/Sdk/tools:/usr/lib/jvm/java-1.8.0-openjdk-amd64/bin:/usr/local/spark/3.0.1/bin
LC_MEASUREMENT: en_US.UTF-8
XAUTHORITY: /run/user/1000/gdm/Xauthority
INVOCATION_ID: 241e1bd0f75a4d78a46c1626d5e0d954
XMODIFIERS: @im=ibus
LC_TELEPHONE: en_US.UTF-8
XDG_DATA_DIRS: /usr/share/ubuntu:/usr/local/share/:/usr/share/:/var/lib/snapd/desktop
GDMSESSION: ubuntu
LC_TIME: en_US.UTF-8
CONDA_DEFAULT_ENV: bert
PAPERSIZE: letter
CONDA_PYTHON_EXE: /home/hironeo/anaconda3/bin/python
DBUS_SESSION_BUS_ADDRESS: unix:path=/run/user/1000/bus
ANDROID_HOME: /home/hironeo/Android/Sdk
CONDA_PREFIX: /home/hironeo/anaconda3/envs/bert
XDG_CURRENT_DESKTOP: ubuntu:GNOME
JOURNAL_STREAM: 8:41711
SSH_AGENT_PID: 1641
COLORTERM: truecolor
LD_LIBRARY_PATH: :/home/hironeo/hadoop/lib/native
LC_PAPER: en_US.UTF-8
SESSION_MANAGER: local/hironeo:@/tmp/.ICE-unix/1673,unix/hironeo:/tmp/.ICE-unix/1673
USERNAME: hironeo
LOGNAME: hironeo
PWD: /home/hironeo/Desktop/djl
MANAGERPID: 1425
IM_CONFIG_PHASE: 1
HADOOP_INSTALL: /home/hironeo/hadoop
GJS_DEBUG_TOPICS: JS ERROR;JS LOG
SHELL: /bin/bash
LESSOPEN: | /usr/bin/lesspipe %s
LC_ADDRESS: en_US.UTF-8
PROJ_LIB: /home/hironeo/anaconda3/envs/bert/share/proj
OLDPWD: /home/hironeo/Desktop/djl
GNOME_DESKTOP_SESSION_ID: this-is-deprecated
GNOME_TERMINAL_SCREEN: /org/gnome/Terminal/screen/57cdfbad_c81a_4c52_97d0_41d60bae5c7f
GTK_MODULES: gail:atk-bridge
HADOOP_HOME: /home/hironeo/hadoop
CONDA_PROMPT_MODIFIER: (bert) 
LS_COLORS: rs=0:di=01;34:ln=01;36:mh=00:pi=40;33:so=01;35:do=01;35:bd=40;33;01:cd=40;33;01:or=40;31;01:mi=00:su=37;41:sg=30;43:ca=30;41:tw=30;42:ow=34;42:st=37;44:ex=01;32:*.tar=01;31:*.tgz=01;31:*.arc=01;31:*.arj=01;31:*.taz=01;31:*.lha=01;31:*.lz4=01;31:*.lzh=01;31:*.lzma=01;31:*.tlz=01;31:*.txz=01;31:*.tzo=01;31:*.t7z=01;31:*.zip=01;31:*.z=01;31:*.dz=01;31:*.gz=01;31:*.lrz=01;31:*.lz=01;31:*.lzo=01;31:*.xz=01;31:*.zst=01;31:*.tzst=01;31:*.bz2=01;31:*.bz=01;31:*.tbz=01;31:*.tbz2=01;31:*.tz=01;31:*.deb=01;31:*.rpm=01;31:*.jar=01;31:*.war=01;31:*.ear=01;31:*.sar=01;31:*.rar=01;31:*.alz=01;31:*.ace=01;31:*.zoo=01;31:*.cpio=01;31:*.7z=01;31:*.rz=01;31:*.cab=01;31:*.wim=01;31:*.swm=01;31:*.dwm=01;31:*.esd=01;31:*.jpg=01;35:*.jpeg=01;35:*.mjpg=01;35:*.mjpeg=01;35:*.gif=01;35:*.bmp=01;35:*.pbm=01;35:*.pgm=01;35:*.ppm=01;35:*.tga=01;35:*.xbm=01;35:*.xpm=01;35:*.tif=01;35:*.tiff=01;35:*.png=01;35:*.svg=01;35:*.svgz=01;35:*.mng=01;35:*.pcx=01;35:*.mov=01;35:*.mpg=01;35:*.mpeg=01;35:*.m2v=01;35:*.mkv=01;35:*.webm=01;35:*.ogm=01;35:*.mp4=01;35:*.m4v=01;35:*.mp4v=01;35:*.vob=01;35:*.qt=01;35:*.nuv=01;35:*.wmv=01;35:*.asf=01;35:*.rm=01;35:*.rmvb=01;35:*.flc=01;35:*.avi=01;35:*.fli=01;35:*.flv=01;35:*.gl=01;35:*.dl=01;35:*.xcf=01;35:*.xwd=01;35:*.yuv=01;35:*.cgm=01;35:*.emf=01;35:*.ogv=01;35:*.ogx=01;35:*.aac=00;36:*.au=00;36:*.flac=00;36:*.m4a=00;36:*.mid=00;36:*.midi=00;36:*.mka=00;36:*.mp3=00;36:*.mpc=00;36:*.ogg=00;36:*.ra=00;36:*.wav=00;36:*.oga=00;36:*.opus=00;36:*.spx=00;36:*.xspf=00;36:
XDG_SESSION_DESKTOP: ubuntu
SHLVL: 1
LC_IDENTIFICATION: en_US.UTF-8
LESSCLOSE: /usr/bin/lesspipe %s %s
LC_MONETARY: en_US.UTF-8
QT_IM_MODULE: ibus
CONDA_EXE: /home/hironeo/anaconda3/bin/conda
JAVA_HOME: /usr/lib/jvm/java-1.8.0-openjdk-amd64
TERM: xterm-256color
XDG_CONFIG_DIRS: /etc/xdg/xdg-ubuntu:/etc/xdg
GNOME_TERMINAL_SERVICE: :1.3001
LANG: en_US.UTF-8
XDG_SESSION_TYPE: x11
DISPLAY: :1
SPARK_HOME: /usr/local/spark/3.0.1
_CE_M: 
YARN_HOME: /home/hironeo/hadoop
HADOOP_HDFS_HOME: /home/hironeo/hadoop
HADOOP_MAPRED_HOME: /home/hironeo/hadoop
HADOOP_COMMON_HOME: /home/hironeo/hadoop
LC_NAME: en_US.UTF-8
CONDA_SHLVL: 1
XDG_SESSION_CLASS: user
_: ./gradlew
HADOOP_OPTS: -Djava.library.path=/home/hironeo/hadoop/lib/native
GPG_AGENT_INFO: /run/user/1000/gnupg/S.gpg-agent:0:1
DESKTOP_SESSION: ubuntu
ANDROID_SDK_ROOT: /home/hironeo/Android/Sdk
USER: hironeo
XDG_MENU_PREFIX: gnome-
VTE_VERSION: 6003
QT_ACCESSIBILITY: 1
WINDOWPATH: 2
LC_NUMERIC: en_US.UTF-8
GJS_DEBUG_OUTPUT: stderr
SSH_AUTH_SOCK: /tmp/ssh-YLMLtWgoVkXr/agent.1581
_CE_CONDA: 
HADOOP_COMMON_LIB_NATIVE_DIR: /home/hironeo/hadoop/lib/native
GNOME_SHELL_SESSION_MODE: ubuntu
XDG_RUNTIME_DIR: /run/user/1000
HOME: /home/hironeo

-------------- Directories --------------
temp directory: /tmp
Engine cache directory: /home/hironeo/.djl.ai

------------------ CUDA -----------------
GPU Count: 0
Default Device: cpu()

----------------- Engines ---------------
Default Engine: MXNet
[DEBUG] - Using cache dir: /home/hironeo/.djl.ai/mxnet
MXNet:1.7.0, capabilities: [
        SIGNAL_HANDLER,
        LAPACK,
        CPU_SSE2,
        CPU_SSE3,
        OPENCV,
        CPU_SSE,
        CPU_AVX,
        F16C,
        BLAS_OPEN,
        CPU_SSE4_2,
        DIST_KVSTORE,
        CPU_SSE4_1,
        OPENMP,
        MKLDNN,
]
MXNet Library: /home/hironeo/.djl.ai/mxnet/1.7.0-backport-mkl-linux-x86_64/libmxnet.so

--------------- Hardware --------------
Available processors (cores): 4
Byte Order: LITTLE_ENDIAN
Free memory (bytes): 95577616
Maximum memory (bytes): 1819803648
Total memory available to JVM (bytes): 122683392
Heap committed: 122683392
Heap nonCommitted: 21037056
GCC: 
gcc (Ubuntu 9.3.0-17ubuntu1~20.04) 9.3.0
Copyright (C) 2019 Free Software Foundation, Inc.
This is free software; see the source for copying conditions.  There is NO
warranty; not even for MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.

Deprecated Gradle features were used in this build, making it incompatible with Gradle 7.0.
Use '--warning-mode all' to show the individual deprecation warnings.
See https://docs.gradle.org/6.7.1/userguide/command_line_interface.html#sec:command_line_warnings

BUILD SUCCESSFUL in 44s
25 actionable tasks: 1 executed, 24 up-to-date
lanking520 commented 3 years ago

@cuongducle Have you compared the input shape? The error message indicates a tensor size mismatch.

If possible, print the NDArray before passing it to inference so we can see its shape. If necessary, you can also share a template model with us so we can reproduce the case on our end.
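Following this suggestion, the shapes can also be checked programmatically before calling predict. In DJL, an NDArray's shape is available as a long[] via getShape().getShape(). The validator below is a hypothetical stdlib-only sketch, assuming the contract implied by torch.jit.trace(encoder, (test_seq, test_seq_length)): the token tensor the model receives is 2-D (batch, seq_len) and the lengths tensor is 1-D with one entry per batch element. The shapes passed in must be the ones the model actually sees, that is, after any batchifier has added its axis.

```java
/** Hypothetical pre-inference check for the traced EncoderRNN inputs. */
final class EncoderInputCheck {
    private EncoderInputCheck() {}

    /** Returns null when the shapes fit the traced contract, otherwise a description of the problem. */
    static String validate(long[] tokenShape, long[] lengthsShape) {
        if (tokenShape.length != 2) {
            return "tokens must be 2-D (batch, seq_len) but have rank " + tokenShape.length;
        }
        if (lengthsShape.length != 1) {
            return "lengths must be 1-D (batch) but have rank " + lengthsShape.length;
        }
        if (lengthsShape[0] != tokenShape[0]) {
            return "lengths has " + lengthsShape[0] + " entries but the batch size is " + tokenShape[0];
        }
        return null;
    }
}
```

For example, the reported (27, 1) tokens after STACK become (1, 27, 1) and fail the rank check, while (1, 27) tokens with a (1) lengths tensor pass.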

cuongducle commented 3 years ago

The input shape was wrong. Thank you.