deepjavalibrary / djl

An Engine-Agnostic Deep Learning Framework in Java
https://djl.ai
Apache License 2.0

Contribute to an example for training AirPassengers model #3063

Open lordofthejars opened 4 months ago

lordofthejars commented 4 months ago

Description

Currently, the docs for time series forecasting use the AirPassengers model to show how to run inference on a time series model with DJL. However, that model is already trained. The docs then move on to the Walmart sales example and show how to train and use that model.

It would also be nice to show how to train the AirPassengers model, as it is a more straightforward time series than the other one.

I wrote the code, which is available here: https://github.com/lordofthejars/airpassengers-forecast

I also put the main class here:

public static TrainingResult runExample(String[] arguments) throws IOException, TranslateException {

        try (Model model = Model.newInstance("deepar")) {
            // specify the model's distribution output; for the M5 case, NegativeBinomial describes it best
            DistributionOutput distributionOutput = new NegativeBinomialOutput();
            DefaultTrainingConfig config = setupTrainingConfig("output/model", 1, distributionOutput);

            NDManager manager = model.getNDManager();
            DeepARNetwork trainingNetwork = getDeepARModel("M", 12, distributionOutput, true);
            model.setBlock(trainingNetwork);

            List<TimeSeriesTransform> trainingTransformation =
                    trainingNetwork.createTrainingTransformation(manager);
            int contextLength = trainingNetwork.getContextLength();

            M5ForecastAirPredictionDataset trainSet =
                    getDataset(trainingTransformation, contextLength, Dataset.Usage.TRAIN);

            try (Trainer trainer = model.newTrainer(config)) {
                trainer.setMetrics(new Metrics());

                System.out.println("+++++" + trainSet.availableSize());
                Shape shape = new Shape(1, trainSet.availableSize());
                trainer.initialize(shape);

                EasyTrain.fit(trainer, 5, trainSet, null);
                return trainer.getTrainingResult();
            }

        }

    }

    private static DeepARNetwork getDeepARModel(String freq, int predictionLength,
            DistributionOutput distributionOutput, boolean training) {
        // cardinality of feat_static_cat; it depends on your dataset, so change it as needed

        List<Integer> cardinality = new ArrayList<>();
        cardinality.add(144 - 32);

        DeepARNetwork.Builder builder =
                DeepARNetwork.builder()
                        .setFreq(freq)
                        .setPredictionLength(predictionLength)
                        .setCardinality(cardinality)
                        .optDistrOutput(distributionOutput)
                        .optUseFeatStaticCat(false);
        return training ? builder.buildTrainingNetwork() : builder.buildPredictionNetwork();
    }

    private static DefaultTrainingConfig setupTrainingConfig(
            String outputDir, int maxGpu, DistributionOutput distributionOutput) {
        SaveModelTrainingListener listener = new SaveModelTrainingListener(outputDir);
        listener.setSaveModelCallback(
                trainer -> {
                    TrainingResult result = trainer.getTrainingResult();
                    Model model = trainer.getModel();
                    float rmsse = result.getValidateEvaluation("RMSSE");
                    model.setProperty("RMSSE", String.format("%.5f", rmsse));
                    model.setProperty("Loss", String.format("%.5f", result.getValidateLoss()));
                });

        return new DefaultTrainingConfig(new DistributionLoss("Loss", distributionOutput))
                .addEvaluator(new Rmsse(distributionOutput))
                .optDevices(Engine.getInstance().getDevices(maxGpu))
                .optInitializer(new XavierInitializer(), Parameter.Type.WEIGHT)
                .addTrainingListeners(TrainingListener.Defaults.logging(outputDir))
                .addTrainingListeners(listener);
    }

    private static M5ForecastAirPredictionDataset getDataset(
            List<TimeSeriesTransform> transformation, int contextLength, Dataset.Usage usage)
            throws IOException {

        M5ForecastAirPredictionDataset.Builder builder =
                M5ForecastAirPredictionDataset.builder()
                        .datasetJson(URI.create("https://resources.djl.ai/test-models/mxnet/timeseries/air_passengers.json").toURL())
                        .train(usage == Dataset.Usage.TRAIN)
                        .setTransformation(transformation)
                        .setContextLength(contextLength)
                        .setSampling(32, usage == Dataset.Usage.TRAIN);

        return builder.build();
    }

    private static final class AirPassengers {

        Date start;
        float[] target;
    }

The problem is that when I ran the method I got:

java.lang.ArrayIndexOutOfBoundsException: Index 3 out of bounds for length 1
    at ai.djl.timeseries.model.deepar.DeepARNetwork.initializeChildBlocks(DeepARNetwork.java:167)
    at ai.djl.nn.AbstractBaseBlock.initialize(AbstractBaseBlock.java:187)
    at ai.djl.training.Trainer.initialize(Trainer.java:117)
    at org.acme.AirPassengerPrediction.runExample(AirPassengerPrediction.java:97)
    at org.acme.AirPassengerPrediction.run(AirPassengerPrediction.java:69)
    at org.acme.AirPassengerPrediction_ClientProxy.run(Unknown Source)
    at io.quarkus.runtime.ApplicationLifecycleManager.run(ApplicationLifecycleManager.java:132)
    at io.quarkus.runtime.Quarkus.run(Quarkus.java:71)
    at io.quarkus.runtime.Quarkus.run(Quarkus.java:44)
    at io.quarkus.runner.GeneratedMain.main(Unknown Source)
    at java.base/jdk.internal.reflect.DirectMethodHandleAccessor.invoke(DirectMethodHandleAccessor.java:103)
    at java.base/java.lang.reflect.Method.invoke(Method.java:580)
    at io.quarkus.runner.bootstrap.StartupActionImpl$1.run(StartupActionImpl.java:113)
    at java.base/java.lang.Thread.run(Thread.java:1583)
    Suppressed: java.lang.NullPointerException: Cannot invoke "java.lang.Float.floatValue()" because the return value of "ai.djl.training.TrainingResult.getValidateEvaluation(String)" is null
        at org.acme.AirPassengerPrediction.lambda$0(AirPassengerPrediction.java:131)
        at ai.djl.training.listener.SaveModelTrainingListener.saveModel(SaveModelTrainingListener.java:151)
        at ai.djl.training.listener.SaveModelTrainingListener.onTrainingEnd(SaveModelTrainingListener.java:90)
        at ai.djl.training.Trainer.lambda$close$2(Trainer.java:330)
        at java.base/java.util.ArrayList.forEach(ArrayList.java:1596)
        at ai.djl.training.Trainer.notifyListeners(Trainer.java:284)
        at ai.djl.training.Trainer.close(Trainer.java:330)
        at org.acme.AirPassengerPrediction.runExample(AirPassengerPrediction.java:101)
        at org.acme.AirPassengerPrediction.run(AirPassengerPrediction.java:69)
        ... 9 more
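
Reading the top-level exception: "Index 3 out of bounds for length 1" is what Java reports when code reads element 3 of a one-element array. Since trainer.initialize(shape) above passes a single Shape, one plausible reading is that the DeepAR training network expects several input shapes and looks one up at index 3. That is an inference from the trace, not something confirmed against the DJL source. The plain-Java sketch below only reproduces the failure mode and shows an early length check; readFourthShape and readFourthShapeChecked are hypothetical stand-ins, not DJL methods.

```java
public class InputShapeCheck {

    // Hypothetical stand-in for a block that looks up one of several input
    // shapes by index, as the stack trace suggests happens during initialization.
    static long[] readFourthShape(long[]... inputShapes) {
        return inputShapes[3];
    }

    // Guarded variant: fail early with a message that names the actual problem.
    static long[] readFourthShapeChecked(long[]... inputShapes) {
        if (inputShapes.length < 4) {
            throw new IllegalArgumentException(
                    "expected at least 4 input shapes but got " + inputShapes.length);
        }
        return inputShapes[3];
    }

    public static void main(String[] args) {
        // A single shape, analogous to passing one Shape to trainer.initialize(...).
        long[] single = {1, 144};

        try {
            readFourthShape(single);
        } catch (ArrayIndexOutOfBoundsException e) {
            System.out.println("unchecked: " + e.getMessage());
        }

        try {
            readFourthShapeChecked(single);
        } catch (IllegalArgumentException e) {
            System.out.println("checked: " + e.getMessage());
        }
    }
}
```

If this reading is right, the fix would be to pass one Shape per network input to trainer.initialize, using whatever layout the DeepAR training network actually expects.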

I think most of the parameters are correct, but I must have missed one or set one incorrectly.
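
The suppressed NullPointerException looks like a separate issue. According to the trace, getValidateEvaluation("RMSSE") returned null and the assignment to a primitive float then unboxed it. That seems consistent with EasyTrain.fit being called with null as the validation dataset, and with training never having started. The sketch below shows a null-safe way to format such a value; the map is a hypothetical stand-in for the validation results, and formatMetric is an illustrative helper, not part of DJL.

```java
import java.util.HashMap;
import java.util.Locale;
import java.util.Map;

public class NullSafeMetric {

    // Format a metric that may be missing; a null Float is never unboxed here.
    static String formatMetric(Float value) {
        return value == null ? "n/a" : String.format(Locale.ROOT, "%.5f", value);
    }

    public static void main(String[] args) {
        // Hypothetical stand-in for validation results: empty, as when no
        // validation dataset was supplied or training never ran.
        Map<String, Float> validateEvaluations = new HashMap<>();

        Float missing = validateEvaluations.get("RMSSE"); // null, no entry yet
        System.out.println("RMSSE = " + formatMetric(missing));

        validateEvaluations.put("RMSSE", 0.5f);
        System.out.println("RMSSE = " + formatMetric(validateEvaluations.get("RMSSE")));
    }
}
```

In the save-model callback, the same idea would mean keeping the result as a Float and checking it for null before calling String.format, or supplying a validation dataset so the value exists.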

If you can help with this, I can clean up the code and then send a PR to the examples folder.

Thank you very much.

Will this change the current api? How?

No

Who will benefit from this enhancement?

Examples provided to all users

References

lordofthejars commented 4 months ago

I've updated the code. It now generates a model, but the model gives really bad results in evaluation. The next step is tuning the parameters, and I would appreciate some help with that.
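
To put a number on "really bad", it can help to compute the error by hand on a held-out window. The sketch below computes RMSSE, the metric registered in the training config, assuming the M5 competition definition: the mean squared forecast error divided by the mean squared one-step naive error on the training history, then square-rooted. Whether DJL's Rmsse evaluator matches this in every detail has not been checked here, and the class and method names are illustrative only.

```java
public class RmsseSketch {

    // RMSSE under the assumed M5 definition:
    // sqrt( mean((actual - forecast)^2) / mean((history[t] - history[t-1])^2) )
    static double rmsse(double[] history, double[] actual, double[] forecast) {
        if (history.length < 2) {
            throw new IllegalArgumentException("history needs at least 2 points");
        }
        if (actual.length == 0 || actual.length != forecast.length) {
            throw new IllegalArgumentException("actual and forecast must be non-empty and equal in length");
        }

        // Scale: mean squared error of the one-step naive forecast on the history.
        // A constant history gives a zero scale, so the result is not finite.
        double scale = 0.0;
        for (int t = 1; t < history.length; t++) {
            double diff = history[t] - history[t - 1];
            scale += diff * diff;
        }
        scale /= (history.length - 1);

        // Mean squared error of the forecast over the prediction window.
        double mse = 0.0;
        for (int i = 0; i < actual.length; i++) {
            double err = actual[i] - forecast[i];
            mse += err * err;
        }
        mse /= actual.length;

        return Math.sqrt(mse / scale);
    }

    public static void main(String[] args) {
        // Made-up toy numbers, not values from the AirPassengers dataset.
        double[] history = {1, 2, 3, 4};
        double[] actual = {5, 6};
        double[] forecast = {5, 8};
        System.out.println("RMSSE = " + rmsse(history, actual, forecast));
    }
}
```

Because the error is scaled by the naive one-step error on the history, a value well above 1 would suggest the forecast is doing worse than that simple baseline.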