nok / sklearn-porter

Transpile trained scikit-learn estimators to C, Java, JavaScript and others.
BSD 3-Clause "New" or "Revised" License
1.28k stars 170 forks source link

Single feature RandomForestClassifier throws index out of range exception #5

Closed lichard49 closed 7 years ago

lichard49 commented 7 years ago

I've built a very simple single feature RandomForestClassifier:

# Minimal reproduction: port a random forest that was trained on a single feature.
# Note: this is Python 2 syntax (xrange, print statement).
from sklearn.ensemble import RandomForestClassifier
import numpy as np

from sklearn_porter import Porter

rf = RandomForestClassifier()
# Ten samples with one feature each (values 0..9); the label is True when the value exceeds 5.
features = [[i] for i in xrange(0, 10)]
labels = [i > 5 for i in xrange(0, 10)]

rf.fit(features, labels)

# Sanity check in Python: predict for every integer in -20..19,
# reshaping each scalar into a (1, 1) array before calling predict.
for feature in xrange(-20, 20):
    print feature, '->', rf.predict(np.array([feature]).reshape(1, -1))

# Port the fitted forest to Java and print the result.
# This port() call is the one that raises the IndexError shown in the traceback.
result = Porter(language='java').port(rf)
print result

which gives the following stack trace:

Traceback (most recent call last):
  File "generateModel.py", line 21, in <module>
    result = Porter(language='java').port(rf)
  File "/usr/local/lib/python2.7/dist-packages/sklearn_porter/__init__.py", line 72, in port
    ported_model = instance.port(model)
  File "/usr/local/lib/python2.7/dist-packages/sklearn_porter/classifier/RandomForestClassifier/__init__.py", line 84, in port
    return self.predict()
  File "/usr/local/lib/python2.7/dist-packages/sklearn_porter/classifier/RandomForestClassifier/__init__.py", line 95, in predict
    return self.create_class(self.create_method())
  File "/usr/local/lib/python2.7/dist-packages/sklearn_porter/classifier/RandomForestClassifier/__init__.py", line 198, in create_method
    tree = self.create_single_method(idx, model)
  File "/usr/local/lib/python2.7/dist-packages/sklearn_porter/classifier/RandomForestClassifier/__init__.py", line 162, in create_single_method
    indices.append([str(j) for j in range(model.n_features_)][i])
IndexError: list index out of range

The line in question indexes into a list that has one entry per feature, but sometimes the index is negative. Python accepts negative indices, but only down to -len(list); anything below that raises an IndexError. In this case, model.n_features_ is 1 (so the list has a single element) but i (the index) is -2, which raises the "list index out of range" exception. What is the best solution for this? Would simply taking the index modulo the length of the list be correct?

Thanks!

nok commented 7 years ago

Thanks, I will check (and fix) it.

nok commented 7 years ago

Hello @lichard49,

I fixed the described behaviour in https://github.com/nok/sklearn-porter/commit/ee66a1d877158a7012ceec55a092d53c983686a9. This is the generated result for your model:

/**
 * Generated Java port of a random forest over a single input feature.
 *
 * The class holds ten tree methods (predict_0 .. predict_9), each consisting of
 * one split on atts[0], and a predict method that combines them by majority vote.
 * Every method is static; the class carries no state.
 */
class Tmp {
    /**
     * Tree 0: a single split on feature 0 at 5.5.
     *
     * @param atts feature vector; only atts[0] is read
     * @return index (0 or 1) of the class with the largest count in the reached leaf
     */
    public static int predict_0(float[] atts) {
        int[] classes = new int[2];

        // Per-class counts of the leaf that is reached
        // (presumably the training-sample counts of that leaf -- confirm against the porter).
        if (atts[0] <= 5.5) {        
            classes[0] = 6;         
            classes[1] = 0;     
        } else {        
            classes[0] = 0;         
            classes[1] = 4;     
        }
        // Argmax over the counts; the strict '>' means a tie keeps the lower index.
        int class_idx = 0;
        int class_val = classes[0];
        for (int i = 1; i < 2; i++) {
            if (classes[i] > class_val) {
                class_idx = i;
                class_val = classes[i];
            }
        }
        return class_idx;
    }

    /** Tree 1: split on feature 0 at 5.5 (leaf counts 7 / 3), then the same argmax as predict_0. */
    public static int predict_1(float[] atts) {
        int[] classes = new int[2];

        if (atts[0] <= 5.5) {        
            classes[0] = 7;         
            classes[1] = 0;     
        } else {        
            classes[0] = 0;         
            classes[1] = 3;     
        }
        int class_idx = 0;
        int class_val = classes[0];
        for (int i = 1; i < 2; i++) {
            if (classes[i] > class_val) {
                class_idx = i;
                class_val = classes[i];
            }
        }
        return class_idx;
    }

    /** Tree 2: split on feature 0 at 5.0 (leaf counts 6 / 4), then the same argmax as predict_0. */
    public static int predict_2(float[] atts) {
        int[] classes = new int[2];

        if (atts[0] <= 5.0) {        
            classes[0] = 6;         
            classes[1] = 0;     
        } else {        
            classes[0] = 0;         
            classes[1] = 4;     
        }
        int class_idx = 0;
        int class_val = classes[0];
        for (int i = 1; i < 2; i++) {
            if (classes[i] > class_val) {
                class_idx = i;
                class_val = classes[i];
            }
        }
        return class_idx;
    }

    /** Tree 3: split on feature 0 at 5.5 (leaf counts 5 / 5), then the same argmax as predict_0. */
    public static int predict_3(float[] atts) {
        int[] classes = new int[2];

        if (atts[0] <= 5.5) {        
            classes[0] = 5;         
            classes[1] = 0;     
        } else {        
            classes[0] = 0;         
            classes[1] = 5;     
        }
        int class_idx = 0;
        int class_val = classes[0];
        for (int i = 1; i < 2; i++) {
            if (classes[i] > class_val) {
                class_idx = i;
                class_val = classes[i];
            }
        }
        return class_idx;
    }

    /** Tree 4: split on feature 0 at 5.5 (leaf counts 7 / 3), then the same argmax as predict_0. */
    public static int predict_4(float[] atts) {
        int[] classes = new int[2];

        if (atts[0] <= 5.5) {        
            classes[0] = 7;         
            classes[1] = 0;     
        } else {        
            classes[0] = 0;         
            classes[1] = 3;     
        }
        int class_idx = 0;
        int class_val = classes[0];
        for (int i = 1; i < 2; i++) {
            if (classes[i] > class_val) {
                class_idx = i;
                class_val = classes[i];
            }
        }
        return class_idx;
    }

    /** Tree 5: split on feature 0 at 6.0 (leaf counts 6 / 4), then the same argmax as predict_0. */
    public static int predict_5(float[] atts) {
        int[] classes = new int[2];

        if (atts[0] <= 6.0) {        
            classes[0] = 6;         
            classes[1] = 0;     
        } else {        
            classes[0] = 0;         
            classes[1] = 4;     
        }
        int class_idx = 0;
        int class_val = classes[0];
        for (int i = 1; i < 2; i++) {
            if (classes[i] > class_val) {
                class_idx = i;
                class_val = classes[i];
            }
        }
        return class_idx;
    }

    /** Tree 6: split on feature 0 at 5.5 (leaf counts 6 / 4), then the same argmax as predict_0. */
    public static int predict_6(float[] atts) {
        int[] classes = new int[2];

        if (atts[0] <= 5.5) {        
            classes[0] = 6;         
            classes[1] = 0;     
        } else {        
            classes[0] = 0;         
            classes[1] = 4;     
        }
        int class_idx = 0;
        int class_val = classes[0];
        for (int i = 1; i < 2; i++) {
            if (classes[i] > class_val) {
                class_idx = i;
                class_val = classes[i];
            }
        }
        return class_idx;
    }

    /** Tree 7: split on feature 0 at 5.5 (leaf counts 7 / 3), then the same argmax as predict_0. */
    public static int predict_7(float[] atts) {
        int[] classes = new int[2];

        if (atts[0] <= 5.5) {        
            classes[0] = 7;         
            classes[1] = 0;     
        } else {        
            classes[0] = 0;         
            classes[1] = 3;     
        }
        int class_idx = 0;
        int class_val = classes[0];
        for (int i = 1; i < 2; i++) {
            if (classes[i] > class_val) {
                class_idx = i;
                class_val = classes[i];
            }
        }
        return class_idx;
    }

    /** Tree 8: split on feature 0 at 5.5 (leaf counts 5 / 5), then the same argmax as predict_0. */
    public static int predict_8(float[] atts) {
        int[] classes = new int[2];

        if (atts[0] <= 5.5) {        
            classes[0] = 5;         
            classes[1] = 0;     
        } else {        
            classes[0] = 0;         
            classes[1] = 5;     
        }
        int class_idx = 0;
        int class_val = classes[0];
        for (int i = 1; i < 2; i++) {
            if (classes[i] > class_val) {
                class_idx = i;
                class_val = classes[i];
            }
        }
        return class_idx;
    }

    /** Tree 9: split on feature 0 at 5.5 (leaf counts 6 / 4), then the same argmax as predict_0. */
    public static int predict_9(float[] atts) {
        int[] classes = new int[2];

        if (atts[0] <= 5.5) {        
            classes[0] = 6;         
            classes[1] = 0;     
        } else {        
            classes[0] = 0;         
            classes[1] = 4;     
        }
        int class_idx = 0;
        int class_val = classes[0];
        for (int i = 1; i < 2; i++) {
            if (classes[i] > class_val) {
                class_idx = i;
                class_val = classes[i];
            }
        }
        return class_idx;
    }

    /**
     * Forest prediction by majority vote over the ten trees.
     *
     * @param atts feature vector; passed unchanged to every tree method
     * @return index of the class with the most votes; a tie keeps the lower index
     */
    public static int predict(float[] atts) {
        int n_classes = 2;
        int[] classes = new int[n_classes];
        // Each tree casts one vote for the class index it returns.
        classes[Tmp.predict_0(atts)]++;
        classes[Tmp.predict_1(atts)]++;
        classes[Tmp.predict_2(atts)]++;
        classes[Tmp.predict_3(atts)]++;
        classes[Tmp.predict_4(atts)]++;
        classes[Tmp.predict_5(atts)]++;
        classes[Tmp.predict_6(atts)]++;
        classes[Tmp.predict_7(atts)]++;
        classes[Tmp.predict_8(atts)]++;
        classes[Tmp.predict_9(atts)]++;

        // Argmax over the vote counts (strict '>' keeps the lower index on a tie).
        int class_idx = 0;
        int class_val = classes[0];
        for (int i = 1; i < n_classes; i++) {
            if (classes[i] > class_val) {
                class_idx = i;
                class_val = classes[i];
            }
        }
        return class_idx;
    }

    /**
     * Command-line entry point: parses the single argument as a float feature
     * value and prints the predicted class index.
     *
     * Nothing is printed when the argument count is anything other than one.
     */
    public static void main(String[] args) {
        // Exactly one argument is expected because the model has a single feature.
        if (args.length == 1) {
            float[] atts = new float[args.length];
            for (int i = 0, l = args.length; i < l; i++) {
                atts[i] = Float.parseFloat(args[i]);
            }
            System.out.println(Tmp.predict(atts));
        }
    }
}

I pushed the changes to the development branch (master). You can install the latest changes with the following commands:

pip uninstall -y sklearn-porter
pip install --no-cache-dir https://github.com/nok/sklearn-porter/zipball/master

Finally thanks for your hint! Feel free to reopen this issue.

Happy coding, Darius 🌵

lichard49 commented 7 years ago

Verified, thanks!