aws / random-cut-forest-by-aws

An implementation of the Random Cut Forest data structure for sketching streaming data, with support for anomaly detection, density estimation, imputation, and more.
https://github.com/aws/random-cut-forest-by-aws
Apache License 2.0
206 stars 33 forks source link

Incorrect forecast result #391

Closed kaituo closed 1 year ago

kaituo commented 1 year ago

I wrote a test for RCFCaster:

package com.amazon.randomcutforest.examples.parkservices;

import java.io.BufferedWriter;
import java.io.FileWriter;
import java.io.IOException;
import java.util.Arrays;

import com.amazon.randomcutforest.config.Precision;
import com.amazon.randomcutforest.config.TransformMethod;
import com.amazon.randomcutforest.examples.Example;
import com.amazon.randomcutforest.parkservices.ForecastDescriptor;
import com.amazon.randomcutforest.parkservices.RCFCaster;
import com.amazon.randomcutforest.parkservices.calibration.Calibration;
import com.amazon.randomcutforest.returntypes.DiVector;
import com.amazon.randomcutforest.returntypes.RangeVector;
import com.amazon.randomcutforest.testutils.MultiDimDataWithKey;
import com.amazon.randomcutforest.testutils.ShingledMultiDimDataWithKeys;

/**
 * Reproduction case for issue #391: feeds a fixed 41-point univariate series
 * into an {@link RCFCaster} and writes the raw series, the trailing observed
 * error estimates, and the forecast for the final timestep to a file named
 * "example" for plotting/inspection.
 */
public class RCFCasterExample2 implements Example {

    public static void main(String[] args) throws Exception {
        new RCFCasterExample2().run();
    }

    @Override
    public String command() {
        return "Calibrated RCFCast";
    }

    @Override
    public String description() {
        return "Calibrated RCFCast Example";
    }

    @Override
    public void run() throws Exception {
        // Create and populate a random cut forest

        int numberOfTrees = 30;
        int sampleSize = 256;
        Precision precision = Precision.FLOAT_32;

        // Multi attribute forecasting is less understood than single attribute
        // forecasting;
        // it is not always clear or easy to decide if multi-attribute forecasting is
        // reasonable
        // but the code below will run for multi-attribute case.
        int baseDimensions = 1;
        int forecastHorizon = 24;
        int shingleSize = 8;
        int outputAfter = 32;

        long seed = 2023L;

        // fixed input series (all values strictly positive — the point of the
        // issue is that the forecast should not go negative)
        double[][] fulldata =
                {
                        {6780.0},
                        {11175.0},
                        {9924.0},
                        {6173.0},
                        {11685.0},
                        {9244.0},
                        {13095.0},
                        {14248.0},
                        {10481.0},
                        {9462.0},
                        {14882.0},
                        {8794.0},
                        {7979.0},
                        {7674.0},
                        {18300.0},
                        {9985.0},
                        {7772.0},
                        {8216.0},
                        {11025.0},
                        {7903.0},
                        {7748.0},
                        {12291.0},
                        {10015.0},
                        {15658.0},
                        {13253.0},
                        {11362.0},
                        {13895.0},
                        {9050.0},
                        {14327.0},
                        {8944.0},
                        {13687.0},
                        {8761.0},
                        {5426.0},
                        {8754.0},
                        {8757.0},
                        {5571.0},
                        {10897.0},
                        {8840.0},
                        {12298.0},
                        {7259.0},
                        {4782.0}
                    };

        int dimensions = baseDimensions * shingleSize;
        // change this line to try other transforms; but the default is NORMALIZE
        // uncomment the transformMethod() below
        TransformMethod transformMethod = TransformMethod.NORMALIZE;
        RCFCaster caster = RCFCaster.builder().dimensions(dimensions).randomSeed(seed + 1).numberOfTrees(numberOfTrees)
                .shingleSize(shingleSize).sampleSize(sampleSize).internalShinglingEnabled(true).precision(precision)
                .anomalyRate(0.01).outputAfter(outputAfter).calibration(Calibration.MINIMAL)
                // the following affects the moving average in many of the transformations
                // the 0.02 corresponds to a half life of 1/0.02 = 50 observations
                // this is different from the timeDecay() of RCF; however it is a similar
                // concept
                .transformDecay(0.02).forecastHorizon(forecastHorizon).initialAcceptFraction(0.125).build();

        String name = "example";
        // try-with-resources ensures the writer is closed even if
        // caster.process(...) throws mid-stream (the original leaked the handle)
        try (BufferedWriter file = new BufferedWriter(new FileWriter(name))) {

            // first block: echo the raw input series, one "index value..." row per point
            for (int j = 0; j < fulldata.length; j++) {
                file.append(j + " ");
                for (int k = 0; k < baseDimensions; k++) {
                    file.append(fulldata[j][k] + " ");
                }
                file.append("\n");
            }
            file.append("\n");
            file.append("\n");

            // stream the series through the caster; only the final descriptor is printed
            for (int j = 0; j < fulldata.length; j++) {
                ForecastDescriptor result = caster.process(fulldata[j], 0L);
                if (j == fulldata.length - 1) {
                    printResult(file, result, j, baseDimensions);
                }
            }
        }

    }

    /**
     * Writes two blocks for one descriptor: the observed per-horizon error
     * statistics over the recent past, then the forecast (value/upper/lower)
     * for each future step.
     *
     * @param file        destination writer (not closed here)
     * @param result      descriptor produced by {@code caster.process}
     * @param current     index of the last processed input point
     * @param inputLength number of attributes per input point (baseDimensions)
     * @throws IOException if the writer fails
     */
    void printResult(BufferedWriter file, ForecastDescriptor result, int current, int inputLength) throws IOException {
        RangeVector forecast = result.getTimedForecast().rangeVector;
        float[] errorP50 = result.getObservedErrorDistribution().values;
        float[] upperError = result.getObservedErrorDistribution().upper;
        float[] lowerError = result.getObservedErrorDistribution().lower;
        DiVector rmse = result.getErrorRMSE();
        float[] mean = result.getErrorMean();
        float[] calibration = result.getCalibration();

        // "1000" is just a plot marker separating sections in the output file
        file.append(current + " " + 1000 + "\n");
        file.append("\n");
        file.append("\n");

        // block corresponding to the past; print the errors
        // (iterated newest-horizon-first so rows come out in ascending time)
        for (int i = forecast.values.length / inputLength - 1; i >= 0; i--) {
            file.append((current - i) + " ");
            for (int j = 0; j < inputLength; j++) {
                int k = i * inputLength + j;
                file.append(mean[k] + " " + rmse.high[k] + " " + rmse.low[k] + " " + errorP50[k] + " " + upperError[k]
                        + " " + lowerError[k] + " " + calibration[k] + " ");
            }
            file.append("\n");
        }
        file.append("\n");
        file.append("\n");

        // block corresponding to the future; the projections and the projected errors
        for (int i = 0; i < forecast.values.length / inputLength; i++) {
            file.append((current + i) + " ");
            for (int j = 0; j < inputLength; j++) {
                int k = i * inputLength + j;
                file.append(forecast.values[k] + " " + forecast.upper[k] + " " + forecast.lower[k] + " ");
            }
            file.append("\n");
        }
        file.append("\n");
        file.append("\n");
    }

}

Here is the output:

0 6780.0 
1 11175.0 
2 9924.0 
3 6173.0 
4 11685.0 
5 9244.0 
6 13095.0 
7 14248.0 
8 10481.0 
9 9462.0 
10 14882.0 
11 8794.0 
12 7979.0 
13 7674.0 
14 18300.0 
15 9985.0 
16 7772.0 
17 8216.0 
18 11025.0 
19 7903.0 
20 7748.0 
21 12291.0 
22 10015.0 
23 15658.0 
24 13253.0 
25 11362.0 
26 13895.0 
27 9050.0 
28 14327.0 
29 8944.0 
30 13687.0 
31 8761.0 
32 5426.0 
33 8754.0 
34 8757.0 
35 5571.0 
36 10897.0 
37 8840.0 
38 12298.0 
39 7259.0 
40 4782.0 

40 1000

17 0.0 0.0 0.0 0.0 3775.3728 -3775.3728 0.0 
18 0.0 0.0 0.0 0.0 3775.3728 -3775.3728 0.0 
19 0.0 0.0 0.0 0.0 3775.3728 -3775.3728 0.0 
20 0.0 0.0 0.0 0.0 3775.3728 -3775.3728 0.0 
21 0.0 0.0 0.0 0.0 3775.3728 -3775.3728 0.0 
22 0.0 0.0 0.0 0.0 3775.3728 -3775.3728 0.0 
23 0.0 0.0 0.0 0.0 3775.3728 -3775.3728 0.0 
24 0.0 0.0 0.0 0.0 3775.3728 -3775.3728 0.0 
25 0.0 0.0 0.0 0.0 3775.3728 -3775.3728 0.0 
26 0.0 0.0 0.0 0.0 3775.3728 -3775.3728 0.0 
27 0.0 0.0 0.0 0.0 3775.3728 -3775.3728 0.0 
28 0.0 0.0 0.0 0.0 3775.3728 -3775.3728 0.0 
29 0.0 0.0 0.0 0.0 3775.3728 -3775.3728 0.0 
30 0.0 0.0 0.0 0.0 3775.3728 -3775.3728 0.0 
31 0.0 0.0 0.0 0.0 3775.3728 -3775.3728 0.0 
32 0.0 0.0 0.0 0.0 3775.3728 -3775.3728 0.0 
33 0.0 0.0 0.0 0.0 3775.3728 -3775.3728 0.0 
34 0.0 0.0 0.0 0.0 3775.3728 -3775.3728 0.0 
35 0.0 0.0 0.0 0.0 3775.3728 -3775.3728 0.0 
36 0.0 0.0 0.0 0.0 3775.3728 -3775.3728 0.0 
37 0.0 0.0 0.0 0.0 3775.3728 -3775.3728 0.0 
38 0.0 0.0 0.0 0.0 3775.3728 -3775.3728 0.0 
39 0.0 0.0 0.0 0.0 3775.3728 -3775.3728 0.0 
40 284943.97 284943.96875 0.0 0.0 31892.232 25096.56 0.0 

40 300538.22 332430.44 300538.22 
41 -280628.22 -276852.84 -284403.6 
42 -280768.06 -276992.7 -284543.44 
43 -280907.9 -277132.53 -284683.28 
44 -281047.75 -277272.38 -284823.12 
45 -281187.62 -277412.25 -284963.0 
46 299699.1 303474.47 295923.72 
47 -281467.3 -277691.94 -285242.7 
48 299419.4 303194.78 295644.03 
49 5715.6978 9491.07 1940.325 
50 299139.7 302915.06 295364.3 
51 -282026.72 -278251.34 -285802.1 
52 -282166.56 -278391.2 -285941.94 
53 -282306.4 -278531.03 -286081.78 
54 -282446.28 -278670.9 -286221.66 
55 -282586.12 -278810.75 -286361.5 
56 298300.6 302075.97 294525.22 
57 -282865.8 -279090.44 -286641.2 
58 298020.88 301796.25 294245.5 
59 4317.1865 8092.5596 541.8137 
60 297741.2 301516.56 293965.8 
61 -283425.22 -279649.84 -287200.6 
62 -283565.1 -279789.72 -287340.47 
63 -283704.94 -279929.56 -287480.3 

In the training data, there are no negative values, but the forecasted values contain many negatives.

sudiptoguha commented 1 year ago

Thanks for posting the issue -- this is a bug due to line 105 in https://github.com/aws/random-cut-forest-by-aws/blob/main/Java/parkservices/src/main/java/com/amazon/randomcutforest/parkservices/preprocessor/InitialSegmentPreprocessor.java

Instead of

tempList[j + 2 * inputLength].update(initialValues[i][j] - tempList[j].getMean());

the correct line should be

tempList[j + 2 * inputLength].update( tempList[j].getMean());

The original line produces an incorrect transformation, which introduces the negative values. After the change the output looks like

0 6780.0 1 11175.0 2 9924.0 3 6173.0 4 11685.0 5 9244.0 6 13095.0 7 14248.0 8 10481.0 9 9462.0 10 14882.0 11 8794.0 12 7979.0 13 7674.0 14 18300.0 15 9985.0 16 7772.0 17 8216.0 18 11025.0 19 7903.0 20 7748.0 21 12291.0 22 10015.0 23 15658.0 24 13253.0 25 11362.0 26 13895.0 27 9050.0 28 14327.0 29 8944.0 30 13687.0 31 8761.0 32 5426.0 33 8754.0 34 8757.0 35 5571.0 36 10897.0 37 8840.0 38 12298.0 39 7259.0 40 4782.0

40 1000

17 0.0 0.0 0.0 0.0 3776.2095 -3776.2095 0.0 18 0.0 0.0 0.0 0.0 3776.2095 -3776.2095 0.0 19 0.0 0.0 0.0 0.0 3776.2095 -3776.2095 0.0 20 0.0 0.0 0.0 0.0 3776.2095 -3776.2095 0.0 21 0.0 0.0 0.0 0.0 3776.2095 -3776.2095 0.0 22 0.0 0.0 0.0 0.0 3776.2095 -3776.2095 0.0 23 0.0 0.0 0.0 0.0 3776.2095 -3776.2095 0.0 24 0.0 0.0 0.0 0.0 3776.2095 -3776.2095 0.0 25 0.0 0.0 0.0 0.0 3776.2095 -3776.2095 0.0 26 0.0 0.0 0.0 0.0 3776.2095 -3776.2095 0.0 27 0.0 0.0 0.0 0.0 3776.2095 -3776.2095 0.0 28 0.0 0.0 0.0 0.0 3776.2095 -3776.2095 0.0 29 0.0 0.0 0.0 0.0 3776.2095 -3776.2095 0.0 30 0.0 0.0 0.0 0.0 3776.2095 -3776.2095 0.0 31 0.0 0.0 0.0 0.0 3776.2095 -3776.2095 0.0 32 0.0 0.0 0.0 0.0 3776.2095 -3776.2095 0.0 33 0.0 0.0 0.0 0.0 3776.2095 -3776.2095 0.0 34 0.0 0.0 0.0 0.0 3776.2095 -3776.2095 0.0 35 0.0 0.0 0.0 0.0 3776.2095 -3776.2095 0.0 36 0.0 0.0 0.0 0.0 3776.2095 -3776.2095 0.0 37 0.0 0.0 0.0 0.0 3776.2095 -3776.2095 0.0 38 0.0 0.0 0.0 0.0 3776.2095 -3776.2095 0.0 39 0.0 0.0 0.0 0.0 3776.2095 -3776.2095 0.0 40 0.0 0.0 0.0 0.0 3776.2095 -3776.2095 0.0

40 10073.035 13849.244 6296.8257 41 9500.144 13276.354 5723.934 42 9307.606 13083.816 5531.397 43 9191.14 12967.35 5414.93 44 9382.338 13158.547 5606.1284 45 8911.4375 12687.646 5135.228 46 8946.831 12723.041 5170.6216 47 8806.98 12583.189 5030.771 48 8667.129 12443.338 4890.9194 49 8339.072 12115.281 4562.863 50 8543.231 12319.441 4767.022 51 8072.33 11848.539 4296.1206 52 8107.7246 11883.934 4331.515 53 7967.873 11744.082 4191.6636 54 7828.022 11604.231 4051.8125 55 7499.9653 11276.175 3723.7559 56 7704.124 11480.334 3927.9146 57 7233.2236 11009.434 3457.0142 58 7268.617 11044.826 3492.4077 59 7128.766 10904.976 3352.5566 60 6988.915 10765.125 3212.7056 61 6660.8584 10437.068 2884.649 62 6865.017 10641.227 3088.8076 63 6394.116 10170.326 2617.9067

This may also be related to issue #390, which is also producing artifacts in anomaly detection. But just this fix does not seem to completely solve #390, so likely there are other issues as well. PR coming shortly.

sudiptoguha commented 1 year ago

Resolved.