Kiyoshika / CppEZML

A work-in-progress C++ machine learning library designed to be very easy to use, with nearly everything written from scratch.

LinearRegression.h - Allow user-defined loss functions #6

Closed: Kiyoshika closed this issue 3 years ago

Kiyoshika commented 3 years ago

Overload the loss function to accept an optional function pointer parameter so users can define their own loss functions.
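One way this can be done is to store a loss function pointer in the class and default it to squared error. Below is a minimal sketch of that pattern; the class name LinearRegressionSketch and all member and parameter names are illustrative assumptions, not code from LinearRegression.h.

// Hypothetical sketch: keep a loss function pointer, defaulting to squared error
// when the user does not supply one. Names are illustrative only.
static double squared_error(double actual_y, double predicted_y) {
    double diff = actual_y - predicted_y;
    return diff * diff;
}

class LinearRegressionSketch {
public:
    // default: built-in squared error loss
    explicit LinearRegressionSketch(bool fit_intercept)
        : fit_intercept(fit_intercept), max_iterations(1000),
          learning_rate(0.001), loss(&squared_error) {}

    // overload: accept a user-defined loss with signature double(double, double)
    LinearRegressionSketch(bool fit_intercept, int max_iterations,
                           double learning_rate, double (*user_loss)(double, double))
        : fit_intercept(fit_intercept), max_iterations(max_iterations),
          learning_rate(learning_rate), loss(user_loss) {}

    // training would accumulate total loss through the stored pointer
    double sample_loss(double actual_y, double predicted_y) const {
        return loss(actual_y, predicted_y);
    }

private:
    bool fit_intercept;
    int max_iterations;
    double learning_rate;
    double (*loss)(double, double);
};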

Kiyoshika commented 3 years ago

Implemented. See the example below comparing the default squared error loss with an absolute error loss:

#include <iostream>
#include <vector>
#include <cmath>
#include "data/DataSet.h"
#include "models/regression/LinearRegression.h"

using namespace std;

// user-defined loss: absolute error for a single actual/predicted pair
double absolute_error(double actual_y, double predicted_y) {
    return std::abs(actual_y - predicted_y);
}

int main() {

    // generated data from Python's sklearn make_regression
    // make_regression(n_samples = 2000, n_features = 100, noise = 0.7)
    // 30% split on training / testing
    DataSet xtrain, xtest, ytrain, ytest;
    xtrain.load("xtrain.csv");
    xtest.load("xtest.csv");
    ytrain.load("ytrain.csv");
    ytest.load("ytest.csv");

    Regressor *se = new LinearRegression(true); // default squared error loss
    Regressor *ae = new LinearRegression(true, 1000, 0.001, &absolute_error); // custom absolute error loss passed as a function pointer

    cout << "Squared Error Model:" << "\n";
    se->fit(xtrain.cast_data_double(), ytrain.cast_target_double());
    cout << "\nAbsolute Error Model:" << "\n";
    ae->fit(xtrain.cast_data_double(), ytrain.cast_target_double());

    vector<double> se_preds = se->predict(xtest.cast_data_double());
    vector<double> ae_preds = ae->predict(xtest.cast_data_double());

    cout << "\n";
    cout << "Squared Error RMSE: " << se->get_rmse(ytest.cast_target_double(), se_preds) << "\n";
    cout << "Absolute Error RMSE: " << ae->get_rmse(ytest.cast_target_double(), ae_preds) << "\n";
    delete se;
    delete ae;

}

OUTPUT FROM ABOVE:

Squared Error Model:
Total loss at iteration #0: 6.44143e+06
Total loss at iteration #100: 764.165
Total loss at iteration #200: 764.165
Total loss at iteration #300: 764.165
Total loss at iteration #400: 764.165
Total loss at iteration #500: 764.165
Total loss at iteration #600: 764.165
Total loss at iteration #700: 764.165
Total loss at iteration #800: 764.165
Total loss at iteration #900: 764.165

Absolute Error Model:
Total loss at iteration #0: 174918
Total loss at iteration #100: 43676.8
Total loss at iteration #200: 812.007
Total loss at iteration #300: 812.38
Total loss at iteration #400: 813.214
Total loss at iteration #500: 810.909
Total loss at iteration #600: 811.659
Total loss at iteration #700: 813.882
Total loss at iteration #800: 812.16
Total loss at iteration #900: 811.485

Squared Error RMSE: 0.75102
Absolute Error RMSE: 0.75168
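Since any function matching the double(double actual, double predicted) signature can be passed, other losses plug in the same way. For illustration, here is a hypothetical Huber loss; the name, the delta threshold of 1.0, and the commented constructor call are assumptions that mirror the example above, not part of the library.

#include <cmath>

// Hypothetical Huber loss: quadratic for small residuals, linear for large ones.
// The delta threshold of 1.0 is an arbitrary choice for this sketch.
double huber_loss(double actual_y, double predicted_y) {
    const double delta = 1.0;
    double residual = std::abs(actual_y - predicted_y);
    if (residual <= delta)
        return 0.5 * residual * residual;
    return delta * (residual - 0.5 * delta);
}

// plugged in exactly like absolute_error in the example above:
// Regressor *hl = new LinearRegression(true, 1000, 0.001, &huber_loss);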