nok / sklearn-porter

Transpile trained scikit-learn estimators to C, Java, JavaScript and others.
BSD 3-Clause "New" or "Revised" License
1.28k stars 170 forks source link

C Code generated is not correct. #23

Closed sanjivsoni17 closed 6 years ago

sanjivsoni17 commented 6 years ago

Features is accessed inside the function predict. The scope of variable features is within main function. It should be either a global variable or passes as function parameter.

#include <stdlib.h>
#include <stdio.h>
#include <math.h>

int predict(float atts[2]) {

    int classes[2];

    if (features[0] <= 5.43762493134) {
        if (features[1] <= 5.74491977692) {
            if (features[0] <= 3.51197504997) {
                classes[0] = 10; 
                classes[1] = 0; 
            } else {
                classes[0] = 0; 
                classes[1] = 1; 
            }
        } else {
            if (features[1] <= 16.6829204559) {
                if (features[0] <= 2.67515516281) {
                    if (features[1] <= 11.5629148483) {
                        if (features[1] <= 7.29798984528) {
                            if (features[1] <= 6.13995504379) {
                                classes[0] = 1; 
                                classes[1] = 0; 
                            } else {
                                classes[0] = 0; 
                                classes[1] = 3; 
                            }
                        } else {
                            if (features[0] <= 1.60292005539) {
                                if (features[1] <= 8.0366601944) {
                                    classes[0] = 3; 
                                    classes[1] = 0; 
                                } else {
                                    if (features[1] <= 9.11940002441) {
                                        classes[0] = 0; 
                                        classes[1] = 2; 
                                    } else {
                                        if (features[0] <= 1.21078002453) {
                                            if (features[0] <= 1.11364006996) {
                                                classes[0] = 1; 
                                                classes[1] = 0; 
                                            } else {
                                                classes[0] = 0; 
                                                classes[1] = 1; 
                                            }
                                        } else {
                                            classes[0] = 2; 
                                            classes[1] = 0; 
                                        }
                                    }
                                }
                            } else {
                                classes[0] = 6; 
                                classes[1] = 0; 
                            }
                        }
                    } else {
                        if (features[0] <= 2.35693502426) {
                            classes[0] = 0; 
                            classes[1] = 7; 
                        } else {
                            classes[0] = 1; 
                            classes[1] = 0; 
                        }
                    }
                } else {
                    if (features[1] <= 16.5127105713) {
                        if (features[1] <= 12.1385450363) {
                            if (features[1] <= 6.92804527283) {
                                if (features[1] <= 6.25199985504) {
                                    classes[0] = 0; 
                                    classes[1] = 4; 
                                } else {
                                    if (features[0] <= 5.02503490448) {
                                        classes[0] = 2; 
                                        classes[1] = 0; 
                                    } else {
                                        classes[0] = 0; 
                                        classes[1] = 1; 
                                    }
                                }
                            } else {
                                if (features[1] <= 10.6784753799) {
                                    classes[0] = 0; 
                                    classes[1] = 9; 
                                } else {
                                    if (features[1] <= 10.7935905457) {
                                        classes[0] = 1; 
                                        classes[1] = 0; 
                                    } else {
                                        classes[0] = 0; 
                                        classes[1] = 5; 
                                    }
                                }
                            }
                        } else {
                            if (features[0] <= 4.75841522217) {
                                if (features[0] <= 3.42268514633) {
                                    classes[0] = 1; 
                                    classes[1] = 0; 
                                } else {
                                    classes[0] = 0; 
                                    classes[1] = 5; 
                                }
                            } else {
                                classes[0] = 2; 
                                classes[1] = 0; 
                            }
                        }
                    } else {
                        classes[0] = 1; 
                        classes[1] = 0; 
                    }
                }
            } else {
                if (features[0] <= 4.17648506165) {
                    classes[0] = 6; 
                    classes[1] = 0; 
                } else {
                    if (features[0] <= 4.91468000412) {
                        classes[0] = 0; 
                        classes[1] = 3; 
                    } else {
                        classes[0] = 2; 
                        classes[1] = 0; 
                    }
                }
            }
        }
    } else {
        if (features[0] <= 7.70522975922) {
            if (features[0] <= 7.64461517334) {
                if (features[0] <= 6.52222013474) {
                    if (features[0] <= 6.49937534332) {
                        if (features[1] <= 8.1920003891) {
                            if (features[1] <= 8.07668018341) {
                                classes[0] = 0; 
                                classes[1] = 4; 
                            } else {
                                classes[0] = 1; 
                                classes[1] = 0; 
                            }
                        } else {
                            classes[0] = 0; 
                            classes[1] = 14; 
                        }
                    } else {
                        classes[0] = 1; 
                        classes[1] = 0; 
                    }
                } else {
                    if (features[1] <= 13.1301851273) {
                        classes[0] = 0; 
                        classes[1] = 41; 
                    } else {
                        if (features[1] <= 13.5656652451) {
                            classes[0] = 1; 
                            classes[1] = 0; 
                        } else {
                            classes[0] = 0; 
                            classes[1] = 7; 
                        }
                    }
                }
            } else {
                classes[0] = 1; 
                classes[1] = 0; 
            }
        } else {
            classes[0] = 0; 
            classes[1] = 183; 
        }
    }

    int index = 0;
    for (int i = 0; i < 2; i++) {
        index = classes[i] > classes[index] ? i : index;
    }
    return index;
}

int main(int argc, const char * argv[]) {

    /* Features: */
    double features[argc-1];
    int i;
    for (i = 1; i < argc; i++) {
        features[i-1] = atof(argv[i]);
    }

    /* Prediction: */
    printf("%d", predict(features));
    return 0;

}
nok commented 6 years ago

Thanks, now it's fixed by the commit #c3d23f0.