jatinchowdhury18 / RTNeural

Real-time neural network inferencing
BSD 3-Clause "New" or "Revised" License

How to get the output of the model when the output is an array of numbers instead of just a single number? #105

alireza1325 closed this 2 months ago

alireza1325 commented 11 months ago

Hi,

In every example you have provided, the networks have just a single number as output. What if the model produces a 1D vector of numbers as output? Could you please share an example with array output? More specifically, I want to train the model in TensorFlow or PyTorch, load the weights with RTNeural, and run the model in real time.

Thanks, Alireza

jatinchowdhury18 commented 11 months ago

For networks that have multiple outputs, the model.getOutputs() method can be used to return a pointer to the output vector.

I don't have an example in this repository at the moment, but here's an example from another project. I can add an example when I've got a minute, but if you happen to have one ready first, that would be cool too :).
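Roughly, the usage pattern looks like this (a minimal sketch using a made-up 8-input / 4-output dense network, not one of the repo's examples):

#include "RTNeural/RTNeural.h"
#include <vector>

// Hypothetical 8-in / 4-out network, just to show the getOutputs() pattern
using ToyModel = RTNeural::ModelT<float, 8, 4,
    RTNeural::DenseT<float, 8, 4>,
    RTNeural::TanhActivationT<float, 4>>;

std::vector<float> runOnce(ToyModel& model, const float (&input)[8])
{
    model.forward(input);                        // run one inference
    const float* outPtr = model.getOutputs();    // pointer to the 4 output values
    return std::vector<float>(outPtr, outPtr + 4);
}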

alireza1325 commented 11 months ago

Hi,

Thank you for your quick reply. I have a deep convolutional neural network which I defined using Keras as follows:

import tensorflow.keras as keras
from tensorflow.keras.layers import Input, Conv1D, PReLU, BatchNormalization
from model_utils import save_model  # assumed: RTNeural's python/model_utils.py helper, not keras.models.save_model
input_shape = (128, 1)

x = Input(shape=input_shape,name = "x")
conv1 = Conv1D(filters=12, kernel_size=65, strides=1, dilation_rate=1, activation=None, padding='valid',name = "conv1")(x)
PRelu1 = PReLU(alpha_initializer='glorot_uniform', shared_axes=[1],name = "PRelu1")(conv1)
bn1 = BatchNormalization(axis=2, momentum=0.0, epsilon=0.01, beta_initializer='random_normal', gamma_initializer='glorot_uniform', moving_mean_initializer="random_normal", moving_variance_initializer="ones",name = "bn1")(PRelu1)
conv2 = Conv1D(filters=8, kernel_size=33, strides=1, dilation_rate=1, activation=None, padding='valid',name = "conv2")(bn1)
PRelu2 = PReLU(alpha_initializer='glorot_uniform', shared_axes=[1],name = "PRelu2")(conv2)
bn2 = BatchNormalization(axis=2, momentum=0.0, epsilon=0.01, beta_initializer='random_normal', gamma_initializer='glorot_uniform', moving_mean_initializer="random_normal", moving_variance_initializer="ones",name = "bn2")(PRelu2)
conv3 = Conv1D(filters=4, kernel_size=13, strides=1, dilation_rate=1, activation=None, padding='valid',name = "conv3")(bn2)
PRelu3 = PReLU(alpha_initializer='glorot_uniform', shared_axes=[1],name = "PRelu3")(conv3)
bn3 = BatchNormalization(axis=2, momentum=0.0, epsilon=0.01, beta_initializer='random_normal', gamma_initializer='glorot_uniform', moving_mean_initializer="random_normal", moving_variance_initializer="ones",name = "bn3")(PRelu3)
conv4 = Conv1D(filters=1, kernel_size=5, strides=1, dilation_rate=1, activation="tanh", padding='valid',name = "conv4")(bn3)

model = keras.Model(inputs=x, outputs=conv4)
model.summary()

save_model(model, "C:/Users/Alireza/OneDrive - University of Manitoba/Experimental Setup/saved models/My_SP8_3.json")

The input to the network has 1 channel with a length of 128, and the output also has 1 channel, with a length of 16 (each "valid" Conv1D shortens the signal by kernel_size - 1, so 128 - 64 - 32 - 12 - 4 = 16). I want to create a static model and get its output so I can compare it with the output of the Python model. I have the following code:

#include "RTNeural/RTNeural.h"
#include "tests/load_csv.hpp"
#include <filesystem>
#include <iostream>

namespace fs = std::filesystem;

std::string getModelFile(fs::path path)
{
    // get path of RTNeural root directory
    while ((--path.end())->string() != "RTNeural")
        path = path.parent_path();

    // get path of model file
    path.append("C:/Users/Alireza/OneDrive - University of Manitoba/Experimental Setup/saved models/My_SP8_3.json");

    return path.string();
}

int main()
{
    std::ifstream modelInputsFile{ "C:/Users/Alireza/OneDrive - University of Manitoba/Experimental Setup/Saved_data/input.csv" };
    std::vector<float> inputs = load_csv::loadFile<float>(modelInputsFile);
    std::cout << "Data with size =  " << inputs.size() <<" are loaded" << std::endl;

    //Mymodel modelt;
    //modelt.load_model();

    RTNeural::ModelT<float, 1, 1,
        RTNeural::Conv1DT<float, 1, 12, 65, 1, false>,
        RTNeural::PReLUActivationT<float, 12>,
        RTNeural::BatchNorm1DT<float, 12, true>,
        RTNeural::Conv1DT<float, 12, 8, 33, 1, false>,
        RTNeural::PReLUActivationT<float, 8>,
        RTNeural::BatchNorm1DT<float, 8, true>,
        RTNeural::Conv1DT<float, 8, 4, 13, 1, false>,
        RTNeural::PReLUActivationT<float, 4>,
        RTNeural::BatchNorm1DT<float, 4, true>,
        RTNeural::Conv1DT<float, 4, 1, 5, 1, false>,
        RTNeural::TanhActivationT<float, 1>> modelt;

    auto executablePath = fs::weakly_canonical(fs::path("G:/Master thesis control/Experimental setup/RT_ANC/RTNeural-main/RTNeural"));
    auto modelFilePath = getModelFile(executablePath);

    std::cout << "Loading model from path: " << modelFilePath << std::endl;
    std::ifstream jsonStream(modelFilePath, std::ifstream::binary);

    modelt.parseJson(jsonStream, true);

    //float testInput[128];
    //for (int i = 0; i < inputs.size(); i++)
    //{
    //    testInput[i] = inputs[i];
    //}

    std::vector<float> outputs(16, 0);

    modelt.reset();

    float input alignas(RTNEURAL_DEFAULT_ALIGNMENT)[128];
    std::copy(inputs.begin(), inputs.begin() +128, input);

    modelt.forward(input);

    std::copy(modelt.getOutputs(), modelt.getOutputs() + 16, outputs.begin());

    for (int i = 0; i < outputs.size(); i++)
    {
        std::cout<< outputs[i]<<" \t";
    }

    std::cout << " \n";

    return 0;

}

The model runs and produces outputs without any errors. However, the output significantly differs from the Python output. Could you please let me know where I made a mistake?

jatinchowdhury18 commented 11 months ago

No problem!

There are two spots where I see potential for some miscommunication between the TensorFlow model and the RTNeural model:

1. The Conv1D layers: RTNeural's Conv1D implementation doesn't support Keras's "valid" padding, so "causal" padding should be used instead (the output will then be the same length as the input).
2. The BatchNormalization layers: the axis argument should be left at its default rather than set explicitly.

For reference, here's one of our test models that has a somewhat similar architecture to yours: https://github.com/jatinchowdhury18/RTNeural/blob/main/python/conv.py

alireza1325 commented 11 months ago

Thanks for your reply.

I changed the Conv1D padding to "causal" and removed the "axis" arguments from the BatchNormalization layers. Now my model looks like this:

import tensorflow.keras as keras
from tensorflow.keras.layers import Input, Conv1D, PReLU, BatchNormalization
input_shape = (128, 1)

x = Input(shape=input_shape,name = "x")
conv1 = Conv1D(filters=12, kernel_size=65, strides=1, dilation_rate=1, activation=None, padding='causal',name = "conv1")(x)
PRelu1 = PReLU(alpha_initializer='glorot_uniform', shared_axes=[1],name = "PRelu1")(conv1)
bn1 = BatchNormalization(momentum=0.0, epsilon=0.01, beta_initializer='random_normal', gamma_initializer='glorot_uniform', moving_mean_initializer="random_normal", moving_variance_initializer="ones",name = "bn1")(PRelu1)
conv2 = Conv1D(filters=8, kernel_size=33, strides=1, dilation_rate=1, activation=None, padding='causal',name = "conv2")(bn1)
PRelu2 = PReLU(alpha_initializer='glorot_uniform', shared_axes=[1],name = "PRelu2")(conv2)
bn2 = BatchNormalization(momentum=0.0, epsilon=0.01, beta_initializer='random_normal', gamma_initializer='glorot_uniform', moving_mean_initializer="random_normal", moving_variance_initializer="ones",name = "bn2")(PRelu2)
conv3 = Conv1D(filters=4, kernel_size=13, strides=1, dilation_rate=1, activation=None, padding='causal',name = "conv3")(bn2)
PRelu3 = PReLU(alpha_initializer='glorot_uniform', shared_axes=[1],name = "PRelu3")(conv3)
bn3 = BatchNormalization(momentum=0.0, epsilon=0.01, beta_initializer='random_normal', gamma_initializer='glorot_uniform', moving_mean_initializer="random_normal", moving_variance_initializer="ones",name = "bn3")(PRelu3)
conv4 = Conv1D(filters=1, kernel_size=5, strides=1, dilation_rate=1, activation="tanh", padding='causal',name = "conv4")(bn3)

model = keras.Model(inputs=x, outputs=conv4)
model.summary()

As you said, the output is now the same size as the input (an array of 128 samples), and the last 16 samples of the output are the same as the output of the previous model with "valid" padding.

I also changed main.cpp to produce 128 output samples, as follows:

std::vector<float> outputs(128, 0);

modelt.reset();

float input alignas(RTNEURAL_DEFAULT_ALIGNMENT)[128];
std::copy(inputs.begin(), inputs.begin() + 128, input);

modelt.forward(input);
//const float* out = modelt.getOutputs();
//std::copy(out, out + 16, outputs.begin());

std::copy(modelt.getOutputs(), modelt.getOutputs() + 128, outputs.begin());

for (int i = 0; i < outputs.size(); i++)
{
    std::cout << outputs[i] << " \t";
}

std::cout << " \n";

This is the output from Keras in python:

-0.11623513 -0.11049354 -0.11022955 -0.10855089 -0.10794831 -0.10736051
 -0.10818011 -0.10837929 -0.10823102 -0.10866965 -0.10826164 -0.10885713
 -0.10866178 -0.10983667 -0.10855884 -0.10777259 -0.10861548 -0.10825714
 -0.10932127 -0.10672156 -0.10904931 -0.11185976 -0.1100981  -0.10470231
 -0.11241413 -0.10575584 -0.10655493 -0.10417628 -0.10324544 -0.11125914
 -0.10714537 -0.11015294 -0.12063145 -0.1104138  -0.11175089 -0.12377845
 -0.10761438 -0.10254148 -0.11289804 -0.12034426 -0.10112429 -0.11914127
 -0.12184981 -0.11031057 -0.10483442 -0.11558313 -0.114549   -0.10393181
 -0.11524037 -0.10497224 -0.11133191 -0.0982298  -0.1074122  -0.11729922
 -0.10720943 -0.11052676 -0.12311254 -0.10740596 -0.09775402 -0.10977318
 -0.10150715 -0.1345206  -0.11838888 -0.13533725 -0.12531397 -0.11149908
 -0.10591609 -0.11578616 -0.10213581 -0.11388975 -0.11689662 -0.10346538
 -0.08569536 -0.11187818 -0.12634481 -0.09311869 -0.11327026 -0.11414964
 -0.0949176  -0.0976704  -0.11621068 -0.09878229 -0.12567036 -0.11186855
 -0.12620872 -0.1171542  -0.10653546 -0.11260238 -0.13113216 -0.13738045
 -0.13606432 -0.13324404 -0.1118001  -0.09250157 -0.10471465 -0.1033114
 -0.09969168 -0.13346806 -0.09274402 -0.11522593 -0.09647407 -0.12134076
 -0.10664527 -0.12027869 -0.11426934 -0.11851404 -0.1114307  -0.10709568
 -0.1212228  -0.10251133 -0.1419652  -0.12830819 -0.12780745 -0.11792901
 -0.11354288 -0.09447688 -0.10965364 -0.1013163  -0.12378055 -0.10835326
 -0.10527924 -0.08881909 -0.10305189 -0.10518864 -0.11651971 -0.11518028
 -0.11890033 -0.10831069

And this is the output of the RTNeural in C++:

# dimensions: 1
Layer: conv1d
  Dims: 12
Layer: prelu
  Dims: 12
Layer: batchnorm
  Dims: 12
Layer: conv1d
  Dims: 8
Layer: prelu
  Dims: 8
Layer: batchnorm
  Dims: 8
Layer: conv1d
  Dims: 4
Layer: prelu
  Dims: 4
Layer: batchnorm
  Dims: 4
Layer: conv1d
  Dims: 1
  activation: tanh
-0.116235 -1.07374e+08 -1.07374e+08 -1.07374e+08 -0.116235 0 0 0 -0.116763 -1.07374e+08 -1.07374e+08 -1.07374e+08 -0.0153737 -0.0141386 -0.00170327 -0.0243574 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 -0.0153737 -0.0141386 -0.00170327 -0.0243574 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 1.4013e-45 0 5.60519e-45 4.2039e-45 2.8026e-45 1.4013e-45 -1.07374e+08 -1.07374e+08 -0.171738 0.0169681 -0.12809 0.215015 -0.177017 0.00818324 0.0701267 -0.134732 -0.119332 -0.156542 -0.183693 0.177653 -0.0652033 -0.05129 -0.0774314 -0.0138738 -0.206785 0.0342959 0.0418634 0.0918676 -0.114144 -1.07374e+08 -1.07374e+08 -1.07374e+08 -0.0153737 -0.0141386 -0.00170327 -0.0243574 1 1 1 1 0 0 0 0 0 0 0 0 1 1 1 1 0.995037 0.995037 0.995037 0.995037 0.01 -1.07374e+08 -1.07374e+08 -1.07374e+08 -0.0154504 -0.0142091 -0.00171176 -0.0244789 0.25 0.25 0.25 0.25 -0.0618017 -0.0568364 -0.00684705 -0.0979155 -0.0082895 -0.00555601 -0.00196131 0.0206646

As you can see, they are not even close, except for the first sample! I am developing an audio application. I used onnxruntime to run my model in C++, and it works fine, producing the same result as Python; however, onnxruntime is not efficient for real-time audio processing, which is why I am trying to use RTNeural as the inference engine. I would appreciate it if you could help me figure out what the problem is with my code.

I have already looked at the example you provided, but it's just a dynamic model; I couldn't find a static implementation of it. Since processing time is crucial for me, I have to use the compile-time API.

Thanks.

jatinchowdhury18 commented 11 months ago

Hmm, it seems like you're conflating the number of inputs/outputs that the model has with the length of the "time axis" that the model is being asked to process. I've modified the example script that I shared to use your model architecture, and then modified your C++ code accordingly. This appears to give output equivalent to the TensorFlow model.

https://gist.github.com/jatinchowdhury18/2fb5c212283b6db89fce1170d4aad6cb
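The key change is that the network is declared with one input and one output per time step, and forward() is then called once per sample along the time axis, roughly like this (a sketch, not the exact gist code):

std::vector<float> testOutputs(inputs.size(), 0.0f);
for (size_t i = 0; i < inputs.size(); ++i)
    testOutputs[i] = modelt.forward(inputs.data() + i); // one sample in, one sample out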

alireza1325 commented 11 months ago

Thank you so much for correcting my mistake. It seems that the model now calls the forward() function once for every single output sample, which leads to a significantly higher processing time for generating the whole output array. I thought it would work like Keras or onnxruntime, which need only one model.predict() call to produce the whole output.

Is there any better method to get the output with a single .forward() call? I ran the same network with onnxruntime, and it takes around 0.3 ms to produce 128 output samples. However, using RTNeural, it takes around 50 ms to produce the same output on the same operating system and laptop. I thought that by using RTNeural I could get better performance than onnxruntime, since it's much lighter than onnxruntime. Am I missing something?

jatinchowdhury18 commented 11 months ago

Right, so RTNeural is designed to be able to process streams of data that are arbitrary in length. If the length of the signal is known ahead of time then it is possible to process the entire signal in one go, as Keras and others allow, which sometimes allows for other optimisations to take place, but RTNeural currently does not support this type of inferencing.

That said, 50 milliseconds definitely seems longer than it should be. Running the C++ code that I shared in the previous message with compiler optimizations turned on (-O3) and using the Eigen backend, I measured the inferencing time (see below) and was consistently getting around 70 microseconds. Obviously we're testing on different machines, but that should be the ballpark to shoot for. With the XSIMD backend, I was getting measurements closer to 130 microseconds. I'm currently testing on an ARM CPU, but I imagine an Intel CPU with AVX support might achieve even better performance.

// (requires <chrono>; testOutputs is a vector the same size as inputs)
namespace chrono = std::chrono;
const auto start = chrono::high_resolution_clock::now();
for (size_t i = 0; i < inputs.size(); ++i)
{
    // one forward() call per sample along the time axis
    testOutputs[i] = modelt.forward(inputs.data() + i);
}
const auto duration = chrono::high_resolution_clock::now() - start;
std::cout << "Time taken by function: " << chrono::duration_cast<chrono::microseconds>(duration).count() << " microseconds" << std::endl;

As for getting better performance out of RTNeural, a couple of thoughts:

I thought that by using RTNeural I could get better performance than onnxruntime, since it's much lighter than onnxruntime.

I'm not sure I fully understand this... RTNeural is "lighter" than onnxruntime in the sense that it has less source code, sure, but that doesn't necessarily make it faster or more efficient. RTNeural is designed for a fairly specific use case, and tries to achieve the best possible performance within the constraints of that use case. For example, with strictly feed-forward architectures like the one you're working with here, it may well be possible for other inferencing engines to out-perform RTNeural by processing the entire signal at once.

I would also wonder if onnxruntime may be doing a better job of making use of the full capabilities of your computer, either by utilising specific hardware that is available (GPU, TPU, Apple's "Neural Engine"), or making better use of the SIMD instruction sets that your CPU might support. RTNeural currently only supports running inference on the CPU, and while it is possible to enable more advanced SIMD instruction sets for RTNeural to make use of, that's generally something the programmer would have to do manually, whereas onnxruntime might do it for you out-of-the-box.

alireza1325 commented 11 months ago

Hello,

Thank you so much for the detailed information. Sorry, you are right; I hadn't turned on optimization. Now I am getting around 110 microseconds to produce the entire 128-sample output on my laptop. Since this model is meant to run on a Raspberry Pi, I compiled and ran it on a Raspberry Pi 4. On that machine, I am getting around 1222 microseconds using the Eigen backend and 2700 using the XSIMD backend, both with -O3 optimization.

I have to get the processing time on the Raspberry Pi down to around 300 microseconds so that the model can run in real time. I only need the last 16 of the 128 output samples. My model originally produced 16 samples; since you said RTNeural doesn't support Keras's "valid" padding, I changed it to "causal" padding, so the model now produces 128 - 16 = 112 unused samples. In other words, it calls model.forward() 112 more times than necessary. When I change the 128 forward calls to 16, I get 141.1 microseconds with the Eigen backend, which is great; that's half of the processing time I got using onnxruntime. RTNeural literally runs the model twice as fast as onnxruntime in this case.

My question is: is there any trick that would let me produce just the 16 output samples instead of all 128? And how hard would it be to implement Keras's "valid" padding in RTNeural myself?

Sorry to ask too many questions; I truly appreciate your time. Thank you.

jatinchowdhury18 commented 11 months ago

No problem, happy to help!

100-1000 microseconds definitely seems like the right ballpark for the time needed for inferencing with this network architecture.

I have to get the processing time on the Raspberry Pi down to around 300 microseconds so that the model can run in real time. I only need the last 16 of the 128 output samples. My model originally produced 16 samples; since you said RTNeural doesn't support Keras's "valid" padding, I changed it to "causal" padding, so the model now produces 128 - 16 = 112 unused samples. In other words, it calls model.forward() 112 more times than necessary. When I change the 128 forward calls to 16, I get 141.1 microseconds with the Eigen backend, which is great; that's half of the processing time I got using onnxruntime. RTNeural literally runs the model twice as fast as onnxruntime in this case.

The way convolutional layers and "causal" padding work is that, since the convolution layers are "stateful", they need some number of input samples in order to build up their full "state". I suppose you could call forward() 112 times when loading the weights for your model, which would allow the model to build up its state, and then only measure the time needed for the final 16 calls? Or alternatively, you could figure out what that state should be ahead of time and load it into the convolutional layers directly (RTNeural doesn't currently have an API for this, but I guess we could make one).
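The first idea might look roughly like this (a rough sketch, reusing the modelt and inputs variables from your code above):

// "Warm up" the convolution state with the first 112 samples once; outputs ignored
for (size_t i = 0; i < 112; ++i)
    modelt.forward(inputs.data() + i);

// ... then only the final 16 calls are needed (and timed)
std::vector<float> lastOutputs(16, 0.0f);
for (size_t i = 112; i < 128; ++i)
    lastOutputs[i - 112] = modelt.forward(inputs.data() + i);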

That said, I'm not sure I'm understanding your use-case, at least as it relates to the usage of RTNeural. As I mentioned, RTNeural is designed to process an arbitrary-length stream of data, so the number of samples being processed shouldn't really matter. If you're able to share some more of the details about what your network is doing in your use-case, I might be able to provide some more insight.

alireza1325 commented 11 months ago

Thanks for your advice. Actually, my model is just a stack of standard 1D convolutions, not temporal convolutions. I read the RTNeural source code, for example conv1d_eigen.h, where you mention that "This implementation was designed to be used for "temporal convolution", so the layer has a "state" made up of past inputs to the layer." But my network doesn't need to know the previous inputs: all the layers are completely stateless, and the output can be calculated from the present input array alone.

The model I have designed is for active noise cancellation. It takes microphone signals as input and produces anti-noise samples. Since I need to run the model in a real-time thread, processing time is crucial.

jatinchowdhury18 commented 11 months ago

Thanks for the extra info, that use-case definitely makes sense. RTNeural does support "stateless" 1D convolutions, but up to now, they've only been used as a "building block" for the Conv2D layer. So at the moment RTNeural::Conv1DStateless does exist, but it can't be used within an RTNeural::ModelT because of some mismatches in the way inputs and outputs of the layers are handled.

In general, I think this is something we need to think about more broadly in RTNeural; however, what is possible right now is to use the layers directly and manage the layer inputs/outputs on your own. I went ahead and made an example of this using the network architecture that you shared, and put it on a branch. I had to make a few other changes so that the layer weights get set correctly, but if you pull that branch, the example should "just work" with the Eigen backend. The STL backend almost works (I think there's something off with how the layer weights are getting set), and I haven't tried to get XSIMD working yet, since that backend tends to have the most issues with I/O mismatches between layers.

When timing the example on my computer I'm seeing that the network can run in approximately 17-18 microseconds. I'm imagining that on the Raspberry Pi it might be something like 2-3 times slower, but hopefully that should still be fast enough for your use-case.

Please let me know if you have any trouble with that branch, or if you have some ideas for making it easier to use these kinds of network architectures within RTNeural in the future!

alireza1325 commented 11 months ago

Hi Jatin,

Thank you so much. Your code and modifications worked great. Running the model with Eigen on an Intel CPU took approximately 18 microseconds, and on the Raspberry Pi around 286 microseconds, which is awesome; these are the same results I got using the onnxruntime library. It seems that RTNeural has many unlocked capabilities. From the way you implemented the network, it's clear that models can be built in a functional style rather than the sequential method, which makes it possible to deploy more complex models with skip connections or concatenations, such as DenseNet or ResNet. In all the examples in the RTNeural repository the layers are created sequentially, so at first I thought RTNeural was quite limited, which turned out not to be true.

If you could work more on the stateless convolutions and make them usable in place of the Conv1D layers, that would be great. RTNeural's Conv1D layer implements temporal convolution, which was developed for capturing long-range dependencies in the input; temporal convolutional networks (TCNs) can be used as a replacement for LSTM or GRU layers. However, tensorflow.keras.layers.Conv1D is a standard convolution, which is stateless.

Moreover, it would be great if the code could produce all output samples of the model with a single model.forward() call. I think this could reduce the computational cost for models that produce an array (1D or 2D) as output instead of a single number.
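As a stopgap, even a thin wrapper like this hypothetical one (not an existing RTNeural API) would give that single-call interface, although the real savings would have to come from the layers processing the whole block internally:

// Hypothetical convenience wrapper: process a whole buffer with one call
// by looping over forward() internally.
template <typename Model>
void processBlock(Model& model, const float* in, float* out, int numSamples)
{
    for (int n = 0; n < numSamples; ++n)
        out[n] = model.forward(in + n);
}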

In addition, I think it shouldn't be difficult to add GPU support to RTNeural, since you already have separate layer classes for the three backends (STL, Eigen, and XSIMD). You would just need to define kernel functions for the conv, dense, activation, and LSTM layers, like the ones I wrote for the dense and activation layers for MNIST classification here.

Once again, I truly appreciate your help. It would have taken me ages to change the RTNeural code to make it suitable for my model.

jake-is-ESD-protected commented 2 months ago

Hi, I just came across this answer because I was wondering the same thing. Since you want to expand the documentation @jatinchowdhury18, as discussed in #135, I would highly suggest including this as well! The answer is as clean and simple as it gets; it just needs to be mentioned in the main README so it can be found more easily.

I might sound pedantic, but I would suggest closing this issue because the question has been answered. The resulting discussion is very interesting, but it doesn't have much to do with the original question, which was answered in the first response. People might come here expecting a solution and see a very long, open issue, which could cause more frustration than necessary, when the answer is plain and simple: model.getOutputs().