deepjavalibrary / djl

An Engine-Agnostic Deep Learning Framework in Java
https://djl.ai
Apache License 2.0

TorchScript concatenating rows of tensors not possible #2326

Closed natkramarz closed 1 year ago

natkramarz commented 1 year ago

Description

When the traced model is invoked through DJL, the TorchScript `torch.cat` call fails to concatenate tensors whose sizes differ along the concatenation dimension.

Expected Behavior

For

```python
category = torch.zeros((1, 18))
input = torch.zeros((1, 59))
hidden = torch.zeros((1, 128))
```

`torch.cat([category, input, hidden], 1)` should produce a tensor of shape `(1, 18 + 59 + 128)`, i.e. `(1, 205)`.
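
A quick way to sanity-check that shape from the DJL side is a small sketch like the following (the class name is illustrative, it assumes a DJL engine such as PyTorch is on the classpath, and `NDArrays.concat` plays the role of `torch.cat`):

```java
import ai.djl.ndarray.NDArray;
import ai.djl.ndarray.NDArrays;
import ai.djl.ndarray.NDList;
import ai.djl.ndarray.NDManager;
import ai.djl.ndarray.types.Shape;

public class ExpectedConcat {
    public static void main(String[] args) {
        try (NDManager manager = NDManager.newBaseManager()) {
            NDArray category = manager.zeros(new Shape(1, 18));
            NDArray input = manager.zeros(new Shape(1, 59));
            NDArray hidden = manager.zeros(new Shape(1, 128));
            // Concatenation along axis 1 only requires the other axes to match,
            // so this yields a (1, 205) array.
            NDArray combined = NDArrays.concat(new NDList(category, input, hidden), 1);
            System.out.println(combined.getShape()); // (1, 205)
        }
    }
}
```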

Error Message

```
ai.djl.translate.TranslateException: ai.djl.engine.EngineException: The following operation failed in the TorchScript interpreter.
Traceback of TorchScript, serialized code (most recent call last):
  File "code/__torch__/___torch_mangle_739.py", line 20, in forward
    i2o = self.i2o
    i2h = self.i2h
    input0 = torch.cat([category, input, hidden], 1)
             ~~~~~~~~~ <--- HERE
Traceback of TorchScript, original code (most recent call last):
/var/folders/1k/bkzbwgm17v75b8nl67y09l000000gq/T/ipykernel_88756/992339029.py(18): forward
/opt/miniconda3/envs/ml/lib/python3.10/site-packages/torch/nn/modules/module.py(1172): _slow_forward
/opt/miniconda3/envs/ml/lib/python3.10/site-packages/torch/nn/modules/module.py(1188): _call_impl
/opt/miniconda3/envs/ml/lib/python3.10/site-packages/torch/jit/_trace.py(957): trace_module
/opt/miniconda3/envs/ml/lib/python3.10/site-packages/torch/jit/_trace.py(753): trace
/var/folders/1k/bkzbwgm17v75b8nl67y09l000000gq/T/ipykernel_88756/143577973.py(8): sample
/var/folders/1k/bkzbwgm17v75b8nl67y09l000000gq/T/ipykernel_88756/143577973.py(32): samples
/var/folders/1k/bkzbwgm17v75b8nl67y09l000000gq/T/ipykernel_88756/143577973.py(34): <module>
/opt/miniconda3/envs/ml/lib/python3.10/site-packages/IPython/core/interactiveshell.py(3430): run_code
/opt/miniconda3/envs/ml/lib/python3.10/site-packages/IPython/core/interactiveshell.py(3341): run_ast_nodes
/opt/miniconda3/envs/ml/lib/python3.10/site-packages/IPython/core/interactiveshell.py(3168): run_cell_async
/opt/miniconda3/envs/ml/lib/python3.10/site-packages/IPython/core/async_helpers.py(129): _pseudo_sync_runner
/opt/miniconda3/envs/ml/lib/python3.10/site-packages/IPython/core/interactiveshell.py(2970): _run_cell
/opt/miniconda3/envs/ml/lib/python3.10/site-packages/IPython/core/interactiveshell.py(2941): run_cell
/opt/miniconda3/envs/ml/lib/python3.10/site-packages/ipykernel/zmqshell.py(528): run_cell
/opt/miniconda3/envs/ml/lib/python3.10/site-packages/ipykernel/ipkernel.py(352): do_execute
/opt/miniconda3/envs/ml/lib/python3.10/site-packages/ipykernel/kernelbase.py(701): execute_request
/opt/miniconda3/envs/ml/lib/python3.10/site-packages/ipykernel/kernelbase.py(383): dispatch_shell
/opt/miniconda3/envs/ml/lib/python3.10/site-packages/ipykernel/kernelbase.py(496): process_one
/opt/miniconda3/envs/ml/lib/python3.10/site-packages/ipykernel/kernelbase.py(510): dispatch_queue
/opt/miniconda3/envs/ml/lib/python3.10/asyncio/events.py(80): _run
/opt/miniconda3/envs/ml/lib/python3.10/asyncio/base_events.py(1863): _run_once
/opt/miniconda3/envs/ml/lib/python3.10/asyncio/base_events.py(597): run_forever
/opt/miniconda3/envs/ml/lib/python3.10/site-packages/tornado/platform/asyncio.py(212): start
/opt/miniconda3/envs/ml/lib/python3.10/site-packages/ipykernel/kernelapp.py(702): start
/opt/miniconda3/envs/ml/lib/python3.10/site-packages/traitlets/config/application.py(980): launch_instance
/opt/miniconda3/envs/ml/lib/python3.10/site-packages/ipykernel_launcher.py(12): <module>
/opt/miniconda3/envs/ml/lib/python3.10/runpy.py(75): _run_code
/opt/miniconda3/envs/ml/lib/python3.10/runpy.py(191): _run_module_as_main
RuntimeError: Sizes of tensors must match except in dimension 1. Expected size 18 but got size 59 for tensor number 1 in the list.
```
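
The error text already hints at the root cause spelled out in the reply below: concatenating the 2-D `(1, N)` tensors along dimension 1 would be legal, so the traced module must be receiving an extra leading dimension. DJL's default batchifier stacks the single-sample input into `(1, 1, N)` tensors, and `torch.cat(..., 1)` then requires axis 2 to match, which is exactly "Expected size 18 but got size 59". A minimal sketch of that shape mismatch using DJL's NDArray API (the class name is illustrative, and it assumes a DJL engine is on the classpath):

```java
import ai.djl.ndarray.NDArray;
import ai.djl.ndarray.NDArrays;
import ai.djl.ndarray.NDList;
import ai.djl.ndarray.NDManager;
import ai.djl.ndarray.types.Shape;

public class BatchDimMismatch {
    public static void main(String[] args) {
        try (NDManager manager = NDManager.newBaseManager()) {
            // Roughly what the traced module sees after the default batchifier
            // stacks the three (1, N) arrays along a new leading batch axis.
            NDArray category = manager.zeros(new Shape(1, 1, 18));
            NDArray input = manager.zeros(new Shape(1, 1, 59));
            NDArray hidden = manager.zeros(new Shape(1, 1, 128));
            // Concatenating along axis 1 now also requires axis 2 to match,
            // so this throws a size-mismatch error analogous to the one above.
            NDArrays.concat(new NDList(category, input, hidden), 1);
        }
    }
}
```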

How to Reproduce?

Steps to reproduce

1. First, create the example model in Python:

```python
import torch
import torch.nn as nn


class RNN(nn.Module):
    def __init__(self, input_size, hidden_size, output_size):
        super(RNN, self).__init__()
        self.hidden_size = hidden_size
        # layer definitions (e.g. self.i2h, self.i2o) not included in this snippet

    def forward(self, category, input, hidden):
        input_combined = torch.zeros((1, category.shape[1] + input.shape[1] + hidden.shape[1]))
        return input, hidden

    def initHidden(self):
        return torch.zeros(1, self.hidden_size)


model = RNN(n_letters, 128, n_letters)
model.eval()


def sample(category, start_letter='A'):
    with torch.no_grad():
        category = category_tensor(category)
        input = input_tensor(start_letter)
        hidden = rnn.initHidden()

        category = torch.zeros((1, 18))
        input = torch.zeros((1, 59))
        hidden = torch.zeros((1, 128))
        example = (category, input, hidden)
        traced_script_module = torch.jit.trace(model, example)
        output_name = start_letter
        output, hidden = model(category, input, hidden)
        torch.jit.save(traced_script_module, "final_traced_rnn.pt")
        return output_name


sample('german', 'R')
```

2. I created a folder `final_traced_rnn`, added the `final_traced_rnn.pt` file to it, and zipped it.

3. In my Java application I run the test:

```java
public class ModelTest {

    RandomNameFromRNN randomNameFromRNN = new RandomNameFromRNN();

    @Test
    void test() {
        randomNameFromRNN.predict();
    }
}
```

Translator and RandomNameFromRNN classes:

```java
public class RandomNameFromRNN {
    InputTranslator translator;
    Model model;

    public RandomNameFromRNN() {
        Path modelDir = null;
        translator = new InputTranslator();

        try {
            var criteria = Criteria.builder()
                    .setTypes(Float.class, Float.class)
                    .optTranslator(new InputTranslator())
                    .optModelPath(Paths.get(getClass().getClassLoader().getResource("final_traced_rnn.zip").toURI()))
                    .build();
            try {
                model = criteria.loadModel();
            } catch (IOException e) {
                throw new RuntimeException(e);
            } catch (ModelNotFoundException e) {
                throw new RuntimeException(e);
            } catch (MalformedModelException e) {
                throw new RuntimeException(e);
            }
        } catch (URISyntaxException e) {
            throw new RuntimeException(e);
        }
    }

    public String predict()  {
        var predictor = model.newPredictor(translator);
        try {
            predictor.predict(0.0F);
        } catch (TranslateException e) {
            throw new RuntimeException(e);
        }
        return "";
    }
}
```

```java
public class InputTranslator implements Translator<Float, Float> {

    @Override
    public NDList processInput(TranslatorContext ctx, Float input) {
        NDManager manager = ctx.getNDManager();
        NDArray categoryArray = manager.zeros(new Shape(1, 18));
        NDArray inputArray = manager.zeros(new Shape(1, 59));
        NDArray hiddenArray = manager.zeros(new Shape(1, 128));
        return new NDList(categoryArray, inputArray, hiddenArray);
    }

    @Override
    public Float processOutput(TranslatorContext ctx, NDList list) {
        return 0.0F;
    }
}
```

What have you tried to solve it?

  1. I tried passing different sizes for the second dimension of the tensors, but it only works when all three tensors use the same size.
  2. The model works in Python, and importing it into a different Python project also works.

Environment Info

Please run the command `./gradlew debugEnv` from the root directory of DJL (if necessary, clone DJL first). It will output information about your system, environment, and installation that can help us debug your issue. Paste the output of the command below:

java.version: 17.0.5
os.arch: aarch64

Running the command ended with:

```
----------------- Engines ---------------
DJL version: 0.21.0-SNAPSHOT
[DEBUG] - Using cache dir: /../djl.ai/mxnet/1.9.1-mkl-osx-aarch64
[DEBUG] - Loading mxnet library from: /../.djl.ai/mxnet/1.9.1-mkl-osx-aarch64/libmxnet.dylib
Exception in thread "main" ai.djl.engine.EngineException: Failed to load MXNet native library
```
frankfliu commented 1 year ago

@natkramarz Can you try the following:

```java
public NDList processInput(TranslatorContext ctx, Float input) {
    NDManager manager = ctx.getNDManager();
    NDArray categoryArray = manager.zeros(new Shape(18));
    NDArray inputArray = manager.zeros(new Shape(59));
    NDArray hiddenArray = manager.zeros(new Shape(128));
    return new NDList(categoryArray, inputArray, hiddenArray);
}
```

DJL will automatically create the batch dimension for you. If you want to batch the input manually, you can implement NoBatchifyTranslator instead:

```java
public class InputTranslator implements NoBatchifyTranslator<Float, Float> {
  ...
}
```
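
For completeness, a minimal sketch of the NoBatchifyTranslator variant, combining the reporter's original translator body with this suggestion (the imports and class name are assumed; with no batchifier, the traced model receives the `(1, N)` tensors unchanged):

```java
import ai.djl.ndarray.NDArray;
import ai.djl.ndarray.NDList;
import ai.djl.ndarray.NDManager;
import ai.djl.ndarray.types.Shape;
import ai.djl.translate.NoBatchifyTranslator;
import ai.djl.translate.TranslatorContext;

public class InputTranslator implements NoBatchifyTranslator<Float, Float> {

    @Override
    public NDList processInput(TranslatorContext ctx, Float input) {
        NDManager manager = ctx.getNDManager();
        // No batchifier runs, so the model sees exactly these
        // (1, 18), (1, 59) and (1, 128) tensors, matching the tracing example.
        NDArray categoryArray = manager.zeros(new Shape(1, 18));
        NDArray inputArray = manager.zeros(new Shape(1, 59));
        NDArray hiddenArray = manager.zeros(new Shape(1, 128));
        return new NDList(categoryArray, inputArray, hiddenArray);
    }

    @Override
    public Float processOutput(TranslatorContext ctx, NDList list) {
        return 0.0F;
    }
}
```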
natkramarz commented 1 year ago

Thank you, both ways work :)