deepjavalibrary / djl

An Engine-Agnostic Deep Learning Framework in Java
https://djl.ai
Apache License 2.0

TimeSeries API Bugs (frequency, context length, FEAT_DYNAMIC_REAL) #3271

Open keklol5050 opened 2 weeks ago

keklol5050 commented 2 weeks ago

If we specify the frequency in the format required by the Lag class or by the GluonTS documentation (for example "1H", "H", "15min"), we get an exception:

Exception in thread "main" java.time.format.DateTimeParseException: Text cannot be parsed to a Duration
    at java.base/java.time.Duration.parse(Duration.java:419)
    at ai.djl.timeseries.transform.feature.Feature.addTimeFeature(Feature.java:127)
    at ai.djl.timeseries.transform.feature.Feature.addTimeFeature(Feature.java:75)
    at ai.djl.timeseries.transform.feature.AddTimeFeature.transform(AddTimeFeature.java:63)
    at ai.djl.timeseries.dataset.TimeSeriesDataset.apply(TimeSeriesDataset.java:105)
    at ai.djl.timeseries.dataset.TimeSeriesDataset.get(TimeSeriesDataset.java:60)
    at ai.djl.training.dataset.DataIterable.fetch(DataIterable.java:170)
    at ai.djl.training.dataset.DataIterable.next(DataIterable.java:145)
    at ai.djl.training.dataset.DataIterable.next(DataIterable.java:43)
    at ai.djl.training.EasyTrain.fit(EasyTrain.java:54)

or

Exception in thread "main" java.time.format.DateTimeParseException: Text cannot be parsed to a Period
    at java.base/java.time.Period.parse(Period.java:349)
    at ai.djl.timeseries.transform.feature.Feature.addTimeFeature(Feature.java:129)
    at ai.djl.timeseries.transform.feature.Feature.addTimeFeature(Feature.java:75)
    at ai.djl.timeseries.transform.feature.AddTimeFeature.transform(AddTimeFeature.java:63)
    at ai.djl.timeseries.dataset.TimeSeriesDataset.apply(TimeSeriesDataset.java:105)
    at ai.djl.timeseries.dataset.TimeSeriesDataset.get(TimeSeriesDataset.java:60)
    at ai.djl.training.dataset.DataIterable.fetch(DataIterable.java:170)
    at ai.djl.training.dataset.DataIterable.next(DataIterable.java:145)
    at ai.djl.training.dataset.DataIterable.next(DataIterable.java:43)
    at ai.djl.training.EasyTrain.fit(EasyTrain.java:54)

If we instead specify the frequency in the format required by the Period and Duration classes (for example "1h", "5h"), we get an exception from DJL:

Exception in thread "main" java.lang.IllegalArgumentException: invalid frequency
    at ai.djl.timeseries.timefeature.Lag.getLagsForFreq(Lag.java:90)
    at ai.djl.timeseries.timefeature.Lag.getLagsForFreq(Lag.java:118)
    at ai.djl.timeseries.model.deepar.DeepARNetwork.<init>(DeepARNetwork.java:127)
    at ai.djl.timeseries.model.deepar.DeepARTrainingNetwork.<init>(DeepARTrainingNetwork.java:26)
    at ai.djl.timeseries.model.deepar.DeepARNetwork$Builder.buildTrainingNetwork(DeepARNetwork.java:608)

If we specify a frequency like "15M", it is interpreted as 15 months, so weekly and monthly frequencies work fine.
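
For reference, here is a minimal sketch of which strings `java.time.Duration` and `java.time.Period` accept on their own, independent of DJL. The class name `FreqParseDemo` and its helpers are hypothetical and only illustrate the JDK parsing rules; they say nothing about how DJL builds the string it passes to these parsers.

```java
import java.time.Duration;
import java.time.Period;
import java.time.format.DateTimeParseException;

public class FreqParseDemo {

    /** Returns true if java.time.Duration can parse the text. */
    public static boolean isDuration(String text) {
        try {
            Duration.parse(text);
            return true;
        } catch (DateTimeParseException e) {
            return false;
        }
    }

    /** Returns true if java.time.Period can parse the text. */
    public static boolean isPeriod(String text) {
        try {
            Period.parse(text);
            return true;
        } catch (DateTimeParseException e) {
            return false;
        }
    }

    public static void main(String[] args) {
        // Time-based units (hours, minutes, seconds) need the "PT" prefix;
        // date-based units (months, weeks, days) use a plain "P" prefix.
        String[] samples = {"1H", "P1H", "PT1H", "PT15M", "P15M", "P1W"};
        for (String s : samples) {
            System.out.printf("%-6s duration=%b period=%b%n", s, isDuration(s), isPeriod(s));
        }
    }
}
```

Note that "P15M" is a valid Period (15 months) while "PT15M" is a valid Duration (15 minutes), which matches the observation that month and week frequencies work while hour and minute frequencies fail.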

keklol5050 commented 2 weeks ago

Also, in GluonTS the cardinality can be "auto" or "ignore" if we don't use categorical features, but the TimeSeries API docs don't explain how to do the same thing here.

keklol5050 commented 2 weeks ago

Also, in GluonTS the context length and prediction length can differ, but when I try this in the TimeSeries API I always get an exception, regardless of the length of the time series:

Exception in thread "main" ai.djl.translate.TranslateException: java.lang.IllegalArgumentException: lags cannot go further than prior sequence length, found lag 157 while prior sequence is only 156-long
    at ai.djl.inference.Predictor.batchPredict(Predictor.java:196)
    at com.crypto.analysis.main.ex.dar.TrainTimeSeries.predict(TrainTimeSeries.java:154)
    at com.crypto.analysis.main.ex.dar.TrainTimeSeries.main(TrainTimeSeries.java:64)
Caused by: java.lang.IllegalArgumentException: lags cannot go further than prior sequence length, found lag 157 while prior sequence is only 156-long
    at ai.djl.timeseries.model.deepar.DeepARNetwork.laggedSequenceValues(DeepARNetwork.java:267)
    at ai.djl.timeseries.model.deepar.DeepARNetwork.unrollLaggedRnn(DeepARNetwork.java:233)
    at ai.djl.timeseries.model.deepar.DeepARPredictionNetwork.forwardInternal(DeepARPredictionNetwork.java:47)
    at ai.djl.nn.AbstractBaseBlock.forward(AbstractBaseBlock.java:79)
    at ai.djl.nn.Block.forward(Block.java:127)
    at ai.djl.inference.Predictor.predictInternal(Predictor.java:146)
    at ai.djl.inference.Predictor.batchPredict(Predictor.java:187)
    ... 2 more

This is the TrainTimeSeries class from the examples, with the following changes:

M5Forecast.Builder builder =
                M5Forecast.builder()
                        .optUsage(usage)
                        .optRepository(BasicDatasets.REPOSITORY)
                        .optGroupId(BasicDatasets.GROUP_ID)
                        .optArtifactId("m5forecast-unittest")
                        .setTransformation(transformation)
                        .setContextLength(8) // here changed
                        .setSampling(32, usage == Dataset.Usage.TRAIN);

DeepARNetwork.Builder builder =
                DeepARNetwork.builder()
                        .setCardinality(cardinality)
                        .setFreq(freq)
                        .setPredictionLength(4)// here changed
                        .optContextLength(8)// here changed
                        .optDistrOutput(distributionOutput)
                        .optUseFeatStaticCat(true);

Map<String, Object> arguments = new ConcurrentHashMap<>();
            arguments.put("prediction_length", 4);
            arguments.put("context_length", 8); // added context_length param

inputShapes[6] = new Shape(1, 8, TimeFeature.timeFeaturesFromFreqStr(freq).size() + 1);
                inputShapes[7] = new Shape(1, 8);
                inputShapes[8] = new Shape(1, 8);
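
The exception message suggests a check of the form "largest lag + subsequence length must not exceed the available sequence length", which is how GluonTS' DeepAR validates lags. The sketch below assumes DJL follows the same convention, including the GluonTS rule that the required history is the context length plus the largest lag. The class `LagCheckDemo` and its methods are hypothetical and are not part of the DJL API.

```java
public class LagCheckDemo {

    /** Returns the largest lag in the array (0 for an empty array). */
    public static int maxLag(int[] lags) {
        int max = 0;
        for (int lag : lags) {
            max = Math.max(max, lag);
        }
        return max;
    }

    /** Assumed form of the failing check: the largest lag plus the subsequence must fit. */
    public static boolean lagsFit(int[] lags, int subsequenceLength, int sequenceLength) {
        return maxLag(lags) + subsequenceLength <= sequenceLength;
    }

    /** Assumed GluonTS convention: history length = context length + largest lag. */
    public static int requiredHistory(int contextLength, int[] lags) {
        return contextLength + maxLag(lags);
    }

    public static void main(String[] args) {
        int[] lags = {1, 2, 157}; // 157 is the lag reported in the exception
        System.out.println("fits in 156 steps: " + lagsFit(lags, 1, 156));
        System.out.println("history needed for context 8: " + requiredHistory(8, lags));
    }
}
```

Under these assumptions, a context length of 8 with a largest lag of 157 needs 165 past values, so a prior sequence that is fixed at 156 steps fails whatever the length of the input series.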

keklol5050 commented 1 week ago

I think the frequency problem is in Feature.addTimeFeature: Duration needs the "PT" prefix, not just "P", for example "PT1H".

Here is my simple workaround, in case anyone else has the same problem:

    public void patchAddTimeFeature() {
        try {
            ClassPool pool = ClassPool.getDefault();
            CtClass cc = pool.get("ai.djl.timeseries.transform.feature.Feature");

            CtMethod method = cc.getDeclaredMethod("addTimeFeature", new CtClass[]{
                    pool.get("ai.djl.ndarray.NDManager"),
                    pool.get("ai.djl.timeseries.dataset.FieldName"),
                    pool.get("ai.djl.timeseries.dataset.FieldName"),
                    pool.get("ai.djl.timeseries.dataset.FieldName"),
                    pool.get("java.util.List"),
                    CtClass.intType,
                    pool.get("java.lang.String"),
                    pool.get("ai.djl.timeseries.TimeSeriesData"),
                    CtClass.booleanType
            });

            if (tf.getMinuteCount() >= 60 && tf.getMinuteCount() < 60*24) { // tf means TimeFrame like 1h,15min,3.5h, etc., and getMinuteCount() returns the tf minutes count, 1h=60, 4h=240, etc.
                method.insertAt(121, "if (freq.endsWith(\"H\") || freq.endsWith(\"T\") || freq.endsWith(\"S\")) { sb.insert(0, \"T\"); }");
            } else if (tf.getMinuteCount() < 60) {
                method.insertAt(124, "formattedFreq = new StringBuilder(freq.replace(\"min\", \"m\").toUpperCase()).insert(0, \"T\").insert(0, \"P\").toString(); freq=\"1H\"; ");
            }
            method.insertAt(132, "System.out.println(timeFreq);");

            cc.toClass();
        } catch (NotFoundException | CannotCompileException e) {
            throw new RuntimeException(e);
        }
    }

The same problem also occurs with minute frequencies.
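
As an alternative to bytecode patching, the conversion itself can be sketched as a small standalone helper that maps GluonTS-style frequency strings to ISO-8601 text that `Duration.parse` or `Period.parse` accepts. This is only an illustration: `FreqToIso` and `toIso8601` are hypothetical names, the unit table covers just the units mentioned in this thread, and nothing here is part of the DJL API.

```java
import java.time.Duration;
import java.time.Period;
import java.util.regex.Matcher;
import java.util.regex.Pattern;

public class FreqToIso {

    // Optional multiplier followed by a unit, e.g. "15min", "1H", "H", "15M".
    private static final Pattern FREQ = Pattern.compile("^(\\d*)([A-Za-z]+)$");

    /** Converts a GluonTS-style frequency such as "15min" to ISO-8601 such as "PT15M". */
    public static String toIso8601(String freq) {
        Matcher m = FREQ.matcher(freq);
        if (!m.matches()) {
            throw new IllegalArgumentException("invalid frequency: " + freq);
        }
        String count = m.group(1).isEmpty() ? "1" : m.group(1);
        switch (m.group(2)) {
            case "S":
                return "PT" + count + "S"; // seconds: time-based, needs "PT"
            case "T":
            case "min":
                return "PT" + count + "M"; // minutes: time-based, needs "PT"
            case "H":
                return "PT" + count + "H"; // hours: time-based, needs "PT"
            case "D":
                return "P" + count + "D"; // days
            case "W":
                return "P" + count + "W"; // weeks
            case "M":
                return "P" + count + "M"; // months
            default:
                throw new IllegalArgumentException("unsupported unit in: " + freq);
        }
    }

    public static void main(String[] args) {
        System.out.println(Duration.parse(toIso8601("1H")));
        System.out.println(Duration.parse(toIso8601("15min")));
        System.out.println(Period.parse(toIso8601("15M")));
    }
}
```

The unit match is case-sensitive on purpose, so "M" (months) and "min" (minutes) cannot be confused; lowercase inputs such as "1h" are rejected rather than guessed at.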

keklol5050 commented 1 week ago

I also got a strange exception when using:

                        .optUseFeatStaticCat(false)
                        .optUseFeatStaticReal(false)
                        .optUseFeatDynamicReal(true)

                   ...
    @Override
        public TimeSeriesData getTimeSeriesData(NDManager manager, long index) {
            float[][] ts = timeSeries.get((int) index);

            TimeSeriesData data = new TimeSeriesData(ts.length);

            data.add(FieldName.TARGET, manager.create(ts[targetPosition]));

            for (int i = 0; i < ts.length; i++) {
                if (i == targetPosition) continue; // skip the target row
                data.add(FieldName.FEAT_DYNAMIC_REAL, manager.create(ts[i]));
            }

            data.setStartTime(dates.get((int) index).toInstant().atZone(ZoneId.of("UTC+0")).toLocalDateTime());
            return data;
        }

Exception in thread "main" java.lang.IllegalArgumentException: FEAT_STATIC_REAL don't map to any NDArray
    at ai.djl.timeseries.transform.convert.Convert.asArray(Convert.java:89)
    at ai.djl.timeseries.transform.convert.AsArray.transform(AsArray.java:56)
    at ai.djl.timeseries.dataset.TimeSeriesDataset.apply(TimeSeriesDataset.java:105)
    at ai.djl.timeseries.dataset.TimeSeriesDataset.get(TimeSeriesDataset.java:60)
    at ai.djl.training.dataset.DataIterable.fetch(DataIterable.java:170)
    at ai.djl.training.dataset.DataIterable.next(DataIterable.java:145)
    at ai.djl.training.dataset.DataIterable.next(DataIterable.java:43)
    at ai.djl.training.EasyTrain.fit(EasyTrain.java:54)

In GluonTS, feat_dynamic_real can be used without feat_static_real.

keklol5050 commented 1 week ago

@frankfliu the problem with lags is here: https://github.com/deepjavalibrary/djl/blob/master/extensions/timeseries/src/main/java/ai/djl/timeseries/translator/BaseTimeSeriesTranslator.java

keklol5050 commented 1 week ago

@frankfliu also, the DeepARNetwork class has a few strange lines of code, such as TRAIN_INPUT_FIELDS and the lines in the attached screenshot.