majianjia / nnom

A higher-level Neural Network library for microcontrollers.
Apache License 2.0

Inconsistent Accuracy Between Python and C Implementation #195

Closed: codygillespie closed this issue 1 year ago

codygillespie commented 1 year ago

Thank you @majianjia for this library.

I have been running into an issue where the accuracy in the C implementation of my model is far lower than what I am seeing in Python. I have put together a complete example here: Complete NNoM Sample.

My input data for training is simply 12 values ranging from 0 to 127 inclusive, and my output labels are simply 0 and 1. All of my data is in data.py in the linked repo.

My model looks like the following and is located in model.py in the linked repo:

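    from keras.models import Sequential
    from keras.layers import Conv1D, Activation, Flatten, Dense, Dropout

    INPUT_SIZE = 12  # 12 input values per sample, as described above

    # main.py imports this function and trains the returned model
    def compile_model():
        model = Sequential()
        model.add(Conv1D(16, kernel_size=5, input_shape=(INPUT_SIZE, 1)))
        model.add(Activation('sigmoid'))
        model.add(Conv1D(32, kernel_size=3))
        model.add(Activation('relu'))
        model.add(Flatten())
        model.add(Dense(32))
        model.add(Activation('relu'))
        model.add(Dropout(0.5))
        model.add(Dense(2))
        model.add(Activation('softmax'))
        model.compile(optimizer='adam', loss='sparse_categorical_crossentropy', metrics=['accuracy'])
        return model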

My training script is in main.py and accepts arguments for the number of epochs to use as well as details about the test/train split. I have been using the following arguments to train for 5 epochs on 80 percent of my data: python ./python_model/main.py -s .8 -r 42 -e 5. My main.py script looks like the following:

    import argparse

    import numpy as np
    from keras import Model
    from sklearn.model_selection import train_test_split

    from model import compile_model
    from constants import MODEL_NAME
    from headers import save_headers
    from data import ALL_DATA

    def main():
        # Required command line arguments:
        # -s: double between 0 and 1, the fraction of the data to be used for training
        # -r: int, the random seed to be used for the random split of the data
        # -e: int, the number of epochs to train the model
        # python .\python_model\main.py -s .8 -r 42 -e 50
        parser = argparse.ArgumentParser(description='Train a neural network model.')
        parser.add_argument('-s', '--split', type=float, required=True,
                            help='fraction of data to use for training (between 0 and 1)')
        parser.add_argument('-r', '--random_seed', type=int, required=True,
                            help='random seed for data split')
        parser.add_argument('-e', '--epochs', type=int, required=True,
                            help='number of epochs to train the model')

        args = parser.parse_args()
        split: float = args.split
        random_seed: int = args.random_seed
        epochs: int = args.epochs

        train, test = train_test_split(ALL_DATA, train_size=split, random_state=random_seed)
        x_train = np.array([x['data'] for x in train])
        y_train = np.array([x['label'] for x in train])
        x_test = np.array([x['data'] for x in test])
        y_test = np.array([x['label'] for x in test])

        model: Model = compile_model()
        model.fit(x_train, y_train, epochs=epochs, batch_size=32, validation_data=(x_test, y_test))
        model.summary()

        model.save(MODEL_NAME)
        loss, acc = model.evaluate(x_test, y_test)
        print(f'loss: {loss}, acc: {acc}')
        save_headers(x_test, y_test)

    if __name__ == "__main__":
        main()

When running this script I am seeing the following loss and accuracy:

    loss: 0.32322487235069275, acc: 0.8782222270965576

However, when I run the model in C with the exported weights.h, the accuracy I am seeing is far lower:

    Total tests: 3375
    Correct tests: 1260
    Accuracy: 0.373333

My main.c program that runs the model in C looks like the following and is in the linked repo:

    #include <stdio.h>
    #include <stdint.h>
    #include "nnom.h"
    #include "weights.h"      // generated model: nnom_model_create() and nnom_input_data[]
    #include "test_data.h"    // exported test inputs
    #include "test_labels.h"  // exported test labels

    int main()
    {
        nnom_model_t *model;
        model = nnom_model_create();
        model_run(model);

        int total_tests = 0;
        int test_correct = 0;
        int test_incorrect = 0;

        uint32_t predicted_label;  // nnom_predict() takes a uint32_t * for the label
        float probability;

        for (int i = 0; i < samples; ++i) {
            total_tests++;
            // Copy one sample into the model's int8 input buffer.
            for (int j = 0; j < size_per_sample; ++j) {
                nnom_input_data[j] = (test_data[i][j]);
            }
            // Run the model and get the index of the highest output.
            nnom_predict(model, &predicted_label, &probability);

            if (predicted_label == test_data_labels[i]) {
                test_correct++;
            }
            else {
                test_incorrect++;
            }
        }

        printf("Total tests: %d\n", total_tests);
        printf("Correct tests: %d\n", test_correct);
        printf("Accuracy: %f\n", (float)test_correct / (float)total_tests);

        return 0;
    }
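The loop above copies test_data into nnom_input_data unchanged. That only matches what the Python model saw if the input layer's 8-bit Q format has 0 fractional bits. Below is a minimal sketch for checking this, assuming the converter picks the input Q format from the data range with a simple max-min rule (fractional bits = 7 - ceil(log2(max|x|))). The helpers input_dec_bits and to_int8_input are hypothetical and not part of NNoM, and the sample rows are made up.

```python
import math

def input_dec_bits(samples, bit_width=8):
    """Hypothetical helper: fractional bits a max-min rule would give the input layer."""
    peak = max(abs(v) for sample in samples for v in sample)
    if peak == 0:
        return bit_width - 1
    return (bit_width - 1) - math.ceil(math.log2(peak))

def to_int8_input(sample, dec_bits):
    """Hypothetical helper: int8 values the input buffer should hold for one sample."""
    return [max(-128, min(127, round(v * 2 ** dec_bits))) for v in sample]

# Made-up raw rows in the 0..127 range described in the issue.
raw = [[0, 5, 64, 127], [3, 90, 12, 1]]
print(input_dec_bits(raw))          # 0: raw values can be copied as-is
print(to_int8_input(raw[0], 0))     # [0, 5, 64, 127]

# If the same data had been normalized to [0, 1] for training,
# a raw copy would no longer match the model's input scale.
scaled = [[v / 127 for v in s] for s in raw]
print(input_dec_bits(scaled))       # 7
print(to_int8_input(scaled[0], 7))  # [0, 5, 65, 127]
```

Under this assumption, inputs that stay in 0..127 map one-to-one onto int8, so a direct copy is consistent with training. If the direct copy is consistent, the mismatch is more likely inside the network than at the input.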

The linked repo contains PowerShell scripts for building the model in Python, building the C application, and running the C application. Would you be able to help me resolve the discrepancy between the implementations? Any help would be appreciated. Thank you again.

majianjia commented 1 year ago

I don't know your training settings, such as the number of epochs.

I recommend using fewer epochs, because quantization-unaware training can lead to extreme weights and activations that push 8-bit quantization out of its effective range. For example, the majority of the data may lie around ±8, while some inputs to certain activations reach -1000. Quantization then has to cover a range of 1000, which sets the Q number to -3 (a step of 8), so most of the data are represented very coarsely.
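The range problem described above can be reproduced with a few lines of arithmetic. This is a minimal sketch, assuming a plain max-min rule for a signed 8-bit Q format (fractional bits = 7 - ceil(log2(max|x|))) with round-to-nearest and saturation. NNoM's own converter may differ in details. The helper names are hypothetical, and the sample values are made up to mirror "data around ±8 plus one activation input near -1000".

```python
import math

def max_min_dec_bits(values, bit_width=8):
    """Hypothetical helper: fractional bits that cover max(|v|) in a signed Q format."""
    peak = max(abs(v) for v in values)
    if peak == 0:
        return bit_width - 1
    return (bit_width - 1) - math.ceil(math.log2(peak))

def quantize(value, dec_bits, bit_width=8):
    """Round to the nearest step of 2**-dec_bits, saturate, return the dequantized value."""
    q_max = 2 ** (bit_width - 1) - 1
    q_min = -(2 ** (bit_width - 1))
    q = max(q_min, min(q_max, round(value * 2 ** dec_bits)))
    return q / 2 ** dec_bits

typical = [-8.0, -3.3, -1.2, 0.4, 2.7, 5.1, 8.0]
with_outlier = typical + [-1000.0]

print(max_min_dec_bits(typical))       # 4: step of 1/16
print(max_min_dec_bits(with_outlier))  # -3: step of 8
print([quantize(v, max_min_dec_bits(with_outlier)) for v in typical])
# [-8.0, 0.0, 0.0, 0.0, 0.0, 8.0, 8.0]
```

With the outlier included, every value strictly between -4 and 4 collapses to 0. This is the kind of precision loss that keeping activations in a narrow range is meant to prevent.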

You may try:

  • Add batch normalization after each conv layer. This constrains the data range.
  • Reduce the number of epochs to 10 and see if accuracy improves. Fewer epochs leave less room for extreme weights.
  • Enable the "KLD" quantisation method instead of the default "max-min". KLD removes extreme values, but it can sometimes hurt the result.

codygillespie commented 1 year ago

Thank you for the help, @majianjia.

We ended up modifying our data preprocessing and input layer, and that seemed to help with quantization. Simply having more neurons in our input layer seemed to mitigate the issue. Thank you again.