oracle / tribuo

Tribuo - A Java machine learning library
https://tribuo.org
Apache License 2.0

Clustering Issue with Loading the Data #358

Closed: Mohammed-Ryiad-Eiadeh closed 7 months ago

Mohammed-Ryiad-Eiadeh commented 7 months ago

Dear all,

My code is:

    // fetch data
    var dataSource = new CSVLoader<>(new ClusteringFactory()).loadDataSource(Paths.get("data.csv"), "Cluster");

    // split the data
    var splitter = new TrainTestSplitter<>(dataSource, 0.6, Trainer.DEFAULT_SEED);
    var train = new MutableDataset<>(splitter.getTrain());
    var test = new MutableDataset<>(splitter.getTest());

    // K_Means cluster
    var numberOfCentroid = 5;
    var trainer = new KMeansTrainer(numberOfCentroid, /* centroids */
            100, /* iterations */
            DistanceType.L1.getDistance(), /* distance function */
            Runtime.getRuntime().availableProcessors(), /* number of compute threads */
            Trainer.DEFAULT_SEED /* RNG seed */
    );
    // train the model
    var sTime = System.currentTimeMillis();
    var learner = trainer.train(train);
    var eTime = System.currentTimeMillis();
    System.out.println("Training with " + numberOfCentroid + " clusters took : " + Util.formatDuration(sTime, eTime));

    // test the model
    System.out.println(new ClusteringEvaluator().evaluate(learner, test));
And my data looks like:

    Feature0 Feature1 Cluster
    4.739796 0.876763 0
    -1.37194 0.768297 0
    1.206231 0.124435 0
    -0.0677 0.670152 0
    1.383023 1.063909 0
    0.948694 0.295163 0
    0.129875 0.897475 0
    2.093981 0.962801 0
    0.966418 2.514052 0
    2.26122 2.143702 0
    3.823891 5.302061 1
    4.174474 4.311026 1
    0.987644 2.291385 1
    4.066038 2.630198 1
    2.384297 1.817028 1
    2.105418 2.71869 1
    3.809 3.569969 1
    3.798313 3.082428 1
    2.648232 3.845926 1
    4.256556 2.006187 1
    4.154176 8.123741 2
    6.343113 7.042842 2
    5.018035 6.386035 2
    6.124884 6.221987 2
    7.4189 5.941632 2
    6.162474 5.681063 2
    5.192215 5.359867 2
    4.824744 5.482966 2
    4.080328 4.913133 2
    6.548129 6.864366 2
    10.56087 9.481709 3
    9.01082 7.915724 3
    8.579698 7.859367 3
    9.513544 7.744482 3
    9.334346 9.108457 3
    8.355501 9.255687 3
    8.077686 9.389924 3
    9.569906 9.77651 3
    9.186204 8.75315 3
    7.819217 8.659295 3

And I got the following exception:

    Exception in thread "main" java.lang.NumberFormatException: For input string: "0.0"
        at java.base/java.lang.NumberFormatException.forInputString(NumberFormatException.java:67)
        at java.base/java.lang.Integer.parseInt(Integer.java:668)
        at java.base/java.lang.Integer.parseInt(Integer.java:786)
        at org.tribuo.clustering.ClusteringFactory.generateOutput(ClusteringFactory.java:87)
        at org.tribuo.clustering.ClusteringFactory.generateOutput(ClusteringFactory.java:40)
        at org.tribuo.data.columnar.processors.response.FieldResponseProcessor.process(FieldResponseProcessor.java:226)
        at org.tribuo.data.columnar.RowProcessor.generateExample(RowProcessor.java:340)
        at org.tribuo.data.columnar.ColumnarDataSource$InnerIterator.hasNext(ColumnarDataSource.java:126)
        at org.tribuo.evaluation.TrainTestSplitter.<init>(TrainTestSplitter.java:94)
        at Main.org.K_Means.main(K_Means.java:22)

Can you help me with this, please?

Mohammed-Ryiad-Eiadeh commented 7 months ago

I found a workaround that actually surprised me: the response column has to be placed as the first column in the dataset.

Craigacp commented 7 months ago

That shouldn't be necessary, it should load the response from whatever column is specified.
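
For example, something like this should load the response from the last column just as well (a minimal sketch, not tested against your file; "data.csv" and the column names are placeholders):

    // Sketch only: the response column is named explicitly, so its position in the
    // CSV should not matter. "data.csv" and "Cluster" are placeholders here.
    var source = new CSVLoader<>(new ClusteringFactory()).loadDataSource(Paths.get("data.csv"), "Cluster");
    var dataset = new MutableDataset<>(source);
    System.out.println("Loaded " + dataset.size() + " examples with " + dataset.getFeatureMap().size() + " features");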

Mohammed-Ryiad-Eiadeh commented 7 months ago

Okay, let me check something before proceeding.

Mohammed-Ryiad-Eiadeh commented 7 months ago

Well, I use the TableSaw library to save my CSV file after generating my own data, and each time I create the file I have to open it and rename the response label to prevent the program from throwing the exception. So far I don't know why.

Craigacp commented 7 months ago

What program do you use to rename the column?

Mohammed-Ryiad-Eiadeh commented 7 months ago

This is my code for generating my own data (ignore the fields and the constructor):

    // The matrix to hold data for all classes
    double[][] columns = new double[numOfInstances * 4][numOfVariables + 1];  // The entire dataset

    // For Class 0
    double[][] tabledDataC0 = new double[numOfInstances][numOfVariables];
    IntStream.range(0, tabledDataC0.length).parallel().forEach(i -> IntStream.range(0, tabledDataC0[0].length).parallel().forEach(ii -> tabledDataC0[i][ii] = mean + Std * random.nextGaussian()));
    for (int i = 0; i < tabledDataC0.length; i++) {
        for (int j = 0; j < tabledDataC0[0].length; j++) {
            columns[i][j] = tabledDataC0[i][j];
            columns[i][numOfVariables] = 0;
        }
    }
    // For Class 1
    double[][] tabledDataC1 = new double[numOfInstances][numOfVariables];
    IntStream.range(0, tabledDataC1.length).parallel().forEach(i -> IntStream.range(0, tabledDataC1[0].length).parallel().forEach(ii -> tabledDataC1[i][ii] = 3 * mean + Std * random.nextGaussian()));
    int row0 = 0;
    for (int i = tabledDataC1.length; i < (tabledDataC1.length * 2); i++) {
        for (int j = 0; j < tabledDataC1[0].length; j++) {
            columns[i][j] = tabledDataC1[row0][j];
            columns[i][numOfVariables] = 1;
        }
        row0++;
    }
    // For Class 2
    double[][] tabledDataC2 = new double[numOfInstances][numOfVariables];
    IntStream.range(0, tabledDataC2.length).parallel().forEach(i -> IntStream.range(0, tabledDataC2[0].length).parallel().forEach(ii -> tabledDataC2[i][ii] = 6 * mean + Std * random.nextGaussian()));
    int row1 = 0;
    for (int i = tabledDataC1.length + tabledDataC2.length; i < (tabledDataC2.length * 3); i++) {
        for (int j = 0; j < tabledDataC2[0].length; j++) {
            columns[i][j] = tabledDataC2[row1][j];
            columns[i][numOfVariables] = 2;
        }
        row1++;
    }
    // For Class 3
    double[][] tabledDataC3 = new double[numOfInstances][numOfVariables];
    IntStream.range(0, tabledDataC3.length).parallel().forEach(i -> IntStream.range(0, tabledDataC3[0].length).parallel().forEach(ii -> tabledDataC3[i][ii] = 9 * mean + Std * random.nextGaussian()));
    int row2 = 0;
    for (int i = tabledDataC1.length + tabledDataC2.length + tabledDataC3.length; i < (tabledDataC3.length * 4); i++) {
        for (int j = 0; j < tabledDataC3[0].length; j++) {
            columns[i][j] = tabledDataC3[row2][j];
            columns[i][numOfVariables] = 3;
        }
        row2++;
    }
    // Store all data as list of lists
    List<List<Double>> data = Arrays.stream(columns).
            map(feature -> Arrays.stream(feature).boxed().collect(Collectors.toList()))
            .collect(Collectors.toList());

    try (CSVWriter csvWriter = new CSVWriter(new FileWriter(path.toString()))) {
        String[] header = new String[data.get(0).size()];
        IntStream.range(0, header.length).forEach(i -> {
            if (i < header.length - 1)
                header[i] = "Feature" + i;
            else
                header[i] = "Class";
        });
        csvWriter.writeNext(header);
        for (List<Double> datum : data) {
            csvWriter.writeNext(datum.stream().map(String::valueOf).toArray(String[]::new));
        }
        csvWriter.flush();
    } catch (IOException exception) {
        exception.printStackTrace();
    }

And when I run the K-Means code, it throws this error unless I open the CSV file and rename the response label; even renaming it to the same name makes it work.

    Exception in thread "main" java.lang.NumberFormatException: For input string: "0.0"
        at java.base/java.lang.NumberFormatException.forInputString(NumberFormatException.java:67)
        at java.base/java.lang.Integer.parseInt(Integer.java:668)
        at java.base/java.lang.Integer.parseInt(Integer.java:786)
        at org.tribuo.clustering.ClusteringFactory.generateOutput(ClusteringFactory.java:87)
        at org.tribuo.clustering.ClusteringFactory.generateOutput(ClusteringFactory.java:40)
        at org.tribuo.data.columnar.processors.response.FieldResponseProcessor.process(FieldResponseProcessor.java:226)
        at org.tribuo.data.columnar.RowProcessor.generateExample(RowProcessor.java:340)
        at org.tribuo.data.columnar.ColumnarDataSource$InnerIterator.hasNext(ColumnarDataSource.java:126)
        at org.tribuo.evaluation.TrainTestSplitter.<init>(TrainTestSplitter.java:94)
        at Main.org.K_Means.main(K_Means.java:22)

Mohammed-Ryiad-Eiadeh commented 7 months ago

I don't know why this problem happened. For displaying the data points colored according to their corresponding clusters, I use this:

    // Now test each sample of the test part in order to display the discrepancy between ground truth and predictions
    var testDataURL = System.getProperty("user.dir") + "\\testData.csv";
    var dataToTest = new CSVLoader<>(new ClusteringFactory()).loadDataSource(Paths.get(testDataURL), "Label");
    var predictor = learner.predict(dataToTest);
    var doubles = DoubleColumn.create("newLabel", predictor.stream().mapToDouble(i -> Double.parseDouble(i.getOutput().toString())).toArray());
    var testedData = Table.read().csv(System.getProperty("user.dir") + "\\testData.csv");
    testedData.replaceColumn("Label", doubles);
    CsvWriteOptions writeOptions = CsvWriteOptions.builder(System.getProperty("user.dir") + "\\testData.csv").build();
    testedData.write().csv(writeOptions);
    var dataTable = Table.read().csv(System.getProperty("user.dir") + "\\testData.csv");
    Plot.show(ScatterPlot.create("Plot", dataTable, "F0", "F1", "newLabel"));

I know this is not efficient, and that I should fetch all the required variables from memory instead of going back through the CSV file, but this is what I can do for now. [attached scatter plot: newplot]
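
A rough in-memory alternative I might try later (just a sketch, untested; it assumes the test features are named F0 and F1 as in the plotting code above) would be to build the Tablesaw table directly from the predictions instead of rewriting testData.csv:

    // Sketch only: pull the feature values and the predicted cluster ids straight
    // out of the predictions, instead of writing and re-reading testData.csv.
    var predictions = learner.predict(dataToTest);
    var f0 = new double[predictions.size()];
    var f1 = new double[predictions.size()];
    var clusterIds = new int[predictions.size()];
    for (int i = 0; i < predictions.size(); i++) {
        for (Feature feature : predictions.get(i).getExample()) {
            if (feature.getName().equals("F0")) { f0[i] = feature.getValue(); }
            if (feature.getName().equals("F1")) { f1[i] = feature.getValue(); }
        }
        clusterIds[i] = predictions.get(i).getOutput().getID();
    }
    var plotTable = Table.create("Predictions",
            DoubleColumn.create("F0", f0),
            DoubleColumn.create("F1", f1),
            IntColumn.create("newLabel", clusterIds));
    Plot.show(ScatterPlot.create("Plot", plotTable, "F0", "F1", "newLabel"));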

Craigacp commented 7 months ago

String.valueOf(0.0) produces "0.0" which Integer.parseInt can't process. I assume when you load the file into something like Excel it's reformatting that column to be integer values on save, which is why it starts working again. So you should store the cluster labels as an int[] when generating the data, that should cause it to write out correctly.
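
Concretely (a tiny sketch, not your code; the labels array is just an illustration):

    // Integer.parseInt cannot handle the decimal point that String.valueOf(0.0) writes out.
    Integer.parseInt("0.0");                  // throws NumberFormatException, as in the stack trace above

    // Keeping the cluster ids as ints avoids the ".0" suffix in the CSV.
    int[] labels = {0, 0, 1, 2, 3};           // illustrative placeholder for the label column
    String cell = String.valueOf(labels[0]);  // "0", not "0.0"
    Integer.parseInt(cell);                   // parses fine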

Mohammed-Ryiad-Eiadeh commented 7 months ago

I changed the data generation to write the label column as integers:

    // Store columns as list of lists
    List<List<Double>> listMain = IntStream.range(0, columns[0].length).mapToObj(i -> Arrays.stream(columns).
            map(column -> column[i]).collect(Collectors.toList())).collect(Collectors.toList());

    // Create a list of Column<?> (the wildcard lets us mix DoubleColumn and IntColumn) to prepare data for the TableSaw library
    List<Column<?>> doubleColumns = new ArrayList<>();
    for (int i = 0; i < listMain.size() - 1; i++) {
        doubleColumns.add(DoubleColumn.create("F" + i, listMain.get(i)));
    }
    doubleColumns.add(IntColumn.create("Label", listMain.get(listMain.size() - 1).stream().mapToInt(Double::intValue).toArray()));
    Table table = Table.create("Data", new ArrayList<>(doubleColumns));
    table.write().csv(path.toString());

It works, thank you, thank you!