deepjavalibrary / djl

An Engine-Agnostic Deep Learning Framework in Java
https://djl.ai
Apache License 2.0

Failed to do timeseries inference with Gluon-TS trained model #2106

Closed: YRChen1998 closed this issue 2 years ago

YRChen1998 commented 2 years ago

Description

I tried to use DJL to run forecasts with a demo DeepAR model trained by Gluon-TS, but inference fails. Here is my Gluon-TS code:

from gluonts.dataset.repository.datasets import get_dataset
from gluonts.mx import DeepAREstimator
from gluonts.mx import Trainer
from pathlib import Path
from gluonts.mx.model import predictor as mxPredictor

dataset = get_dataset("airpassengers")

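# Train DeepAR on the monthly series with a 12-step forecast horizon
deepar = DeepAREstimator(prediction_length=12, freq="M", trainer=Trainer(epochs=5))
model = deepar.train(dataset.train)
# Convert to a SymbolBlock predictor and serialize it (prediction_net-symbol.json + params) for DJL
symbol_predictor: mxPredictor.SymbolBlockPredictor = model.as_symbol_block_predictor(dataset=dataset.train)
symbol_predictor.serialize(Path("/Users/yrchen/Desktop/GluonTS_Projects/test/DeepAR"))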

Then I copied the exported model files into the DJL model directory (model/DeepAR): [screenshot of the model directory]
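DJL finds the MXNet files through optModelPath and optModelName in the Criteria used in the inference code, so the exported files must keep their prefix (here prediction_net). The following is a hypothetical helper, not part of DJL, that checks for the two files an MXNet symbolic export normally produces. The prediction_net-0000.params name assumes the usual epoch-0 suffix and may differ for your export.

```java
import java.nio.file.Files;
import java.nio.file.Path;
import java.nio.file.Paths;
import java.util.ArrayList;
import java.util.List;

public final class CheckModelDir {

    private CheckModelDir() {}

    /** Returns the expected MXNet export files that are missing from modelDir. */
    public static List<String> missingFiles(Path modelDir, String modelName) {
        // <name>-symbol.json holds the graph; <name>-0000.params holds the weights.
        // The "0000" epoch suffix is an assumption about what the export produced.
        String[] expected = {modelName + "-symbol.json", modelName + "-0000.params"};
        List<String> missing = new ArrayList<>();
        for (String fileName : expected) {
            if (!Files.isRegularFile(modelDir.resolve(fileName))) {
                missing.add(fileName);
            }
        }
        return missing;
    }

    public static void main(String[] args) {
        // Same directory and model name as optModelPath/optModelName in the issue code
        List<String> missing = missingFiles(Paths.get("model/DeepAR"), "prediction_net");
        if (missing.isEmpty()) {
            System.out.println("All expected model files are present.");
        } else {
            System.out.println("Missing model files: " + missing);
        }
    }
}
```

A missing file here would produce a load error rather than the shape error in this issue, so this check only rules out the simpler problem.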

However, inference fails when I run the DJL inference code from PR #2055 ("[timeseries] add m5 demo and a simple demo"):

package org.example;

import ai.djl.ModelException;
import ai.djl.inference.Predictor;
import ai.djl.ndarray.NDArray;
import ai.djl.ndarray.NDList;
import ai.djl.ndarray.NDManager;
import ai.djl.repository.zoo.Criteria;
import ai.djl.repository.zoo.ZooModel;
import ai.djl.timeseries.Forecast;
import ai.djl.timeseries.SampleForecast;
import ai.djl.timeseries.TimeSeriesData;
import ai.djl.timeseries.dataset.FieldName;
import ai.djl.timeseries.translator.DeepARTranslator;
import ai.djl.training.util.ProgressBar;
import ai.djl.translate.TranslateException;

import com.google.gson.GsonBuilder;

import org.slf4j.Logger;
import org.slf4j.LoggerFactory;

import java.io.IOException;
import java.io.InputStreamReader;
import java.io.OutputStream;
import java.io.Reader;
import java.net.URL;
import java.nio.charset.StandardCharsets;
import java.nio.file.Files;
import java.nio.file.Path;
import java.nio.file.Paths;
import java.time.LocalDateTime;
import java.time.ZoneId;
import java.util.Arrays;
import java.util.Date;
import java.util.Map;
import java.util.concurrent.ConcurrentHashMap;

public final class TimeSeriesAirPassengers {

    private static final Logger logger = LoggerFactory.getLogger(TimeSeriesAirPassengers.class);

    private TimeSeriesAirPassengers() {}

    public static void main(String[] args) throws IOException, TranslateException, ModelException {

        float[] results = predict();
        logger.info(Arrays.toString(results));
    }

    public static float[] predict() throws IOException, TranslateException, ModelException {

        Map<String, Object> arguments = new ConcurrentHashMap<>();
        arguments.put("prediction_length", 12);
        arguments.put("freq", "M");
        arguments.put("use_" + FieldName.FEAT_DYNAMIC_REAL.name().toLowerCase(), false);
        arguments.put("use_" + FieldName.FEAT_STATIC_CAT.name().toLowerCase(), false);
        arguments.put("use_" + FieldName.FEAT_STATIC_REAL.name().toLowerCase(), false);

        DeepARTranslator translator = DeepARTranslator.builder(arguments).build();

        Criteria<TimeSeriesData, Forecast> criteria =
                Criteria.builder()
                        .setTypes(TimeSeriesData.class, Forecast.class)
                        .optModelPath(Paths.get("model/DeepAR"))
                        .optModelName("prediction_net")
                        .optTranslator(translator)
                        .optProgress(new ProgressBar())
                        .build();

        try (ZooModel<TimeSeriesData, Forecast> model = criteria.loadModel();
             Predictor<TimeSeriesData, Forecast> predictor = model.newPredictor()) {
            NDManager manager = model.getNDManager();

            Path dataFile = Paths.get("src/test/resources/air_passengers.json");
            AirPassengers ap = new AirPassengers(dataFile);
            TimeSeriesData data = ap.get(manager);

            // save data for plotting
            NDArray target = data.get(FieldName.TARGET);
            target.setName("target");
            saveNDArray(target, Paths.get("./target.zip"));

            Forecast forecast = predictor.predict(data);

            // save data for plotting. Please see the corresponding python script from
            // https://gist.github.com/Carkham/a5162c9298bc51fec648a458a3437008
            NDArray samples = ((SampleForecast) forecast).getSortedSamples();
            samples.setName("samples");
            saveNDArray(samples, Paths.get("./samples.zip"));
            return forecast.mean().toFloatArray();
        }
    }

    public static void saveNDArray(NDArray array, Path path) throws IOException {
        try (OutputStream os = Files.newOutputStream(path)) {
            // Encode a single-array NDList; "true" writes it zipped
            new NDList(array).encode(os, true);
        }
    }

    public static class AirPassengers {

        private Path path;
        private AirPassengerData data;

        public AirPassengers(Path path) {
            this.path = path;
            prepare();
        }

        public TimeSeriesData get(NDManager manager) {
            LocalDateTime start =
                    data.start.toInstant().atZone(ZoneId.systemDefault()).toLocalDateTime();
            NDArray target = manager.create(data.target);
            TimeSeriesData ret = new TimeSeriesData(10);
            ret.setStartTime(start);
            ret.setField(FieldName.TARGET, target);
            return ret;
        }

        private void prepare() {
            try {
                URL url = path.toUri().toURL();
                try (Reader reader =
                             new InputStreamReader(url.openStream(), StandardCharsets.UTF_8)) {
                    data =
                            new GsonBuilder()
                                    .setDateFormat("yyyy-MM")
                                    .create()
                                    .fromJson(reader, AirPassengerData.class);
                }
            } catch (IOException e) {
                throw new IllegalArgumentException("Invalid url: " + path, e);
            }
        }

        private static class AirPassengerData {
            Date start;
            float[] target;
        }
    }
}

Error Message

Caused by: ai.djl.engine.EngineException: MXNet engine call failed: MXNetError: Error in operator deeparpredictionnetwork0_lstm0_t0_plus0: [14:13:50] /Users/runner/work/djl/djl/src/ndarray/../operator/tensor/../elemwise_op_common.h:134: Check failed: assign(&dattr, vec.at(i)): Incompatible attr in node deeparpredictionnetwork0_lstm0_t0_plus0 at 1-th input: expected [1,160], got [0,160]

What have you tried to solve it?

  1. I tried switching MXNet versions and DJL MXNet engine versions, but the error persisted.
  2. I also tried the solution from issue #1127 ("Got some problems when loading a model trained by gluonts"), but it didn't work.

Environment Info

implementation("ai.djl.mxnet:mxnet-native-mkl:1.9.1:osx-x86_64")
implementation("ai.djl.mxnet:mxnet-model-zoo:0.19.0")
implementation("ai.djl:basicdataset:0.19.0")
implementation("ai.djl.timeseries:timeseries:0.19.0")

I want to take Gluon-TS trained models to production, and I think DJL is a good choice. Is there a way to solve this problem?

frankfliu commented 2 years ago

Please try the latest PR: https://github.com/deepjavalibrary/djl/pull/2027

Carkham commented 2 years ago

Thanks. This bug comes from MXNet's rnn.begin_state. As a temporary workaround, manually change the batch size in the shape of every begin_state operation in the prediction_net-symbol.json model file to -1 (for example, from (0, 40) to (-1, 40)). Each node should look like this:

{
  "op": "_zeros", 
  "name": "deeparpredictionnetwork0_deeparpredictionnetwork0_lstm0_begin_state_0", 
  "attrs": {
    "__layout__": "NC", 
    "dtype": "float32", 
    "shape": "(-1, 40)"
  }, 
  "inputs": []
}
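Editing every begin_state node by hand is error-prone because the exported graph typically contains several of them. The following is a minimal sketch, not part of DJL or Gluon-TS, that applies the same edit with a regular expression. It assumes the node layout shown above ("name" directly followed by an "attrs" object containing "shape", with no nested braces), and its default path model/DeepAR/prediction_net-symbol.json is taken from the Criteria in the issue code; the class name PatchBeginState is illustrative. A plausible reason -1 works where 0 fails is that in MXNet's NumPy-compatible shape semantics 0 denotes a real zero-size dimension (matching the "got [0,160]" in the error) while -1 denotes an unknown dimension to be inferred; this is an assumption, not something confirmed in the thread.

```java
import java.io.IOException;
import java.nio.charset.StandardCharsets;
import java.nio.file.Files;
import java.nio.file.Path;
import java.nio.file.Paths;
import java.util.regex.Matcher;
import java.util.regex.Pattern;

public final class PatchBeginState {

    // Group 1: from a begin_state node's "name" up to the "(" of its "shape" value.
    // Group 2: the comma after the batch dimension. The "0" between them is replaced.
    private static final Pattern BEGIN_STATE_SHAPE =
            Pattern.compile(
                    "(\"name\":\\s*\"[^\"]*begin_state[^\"]*\",\\s*"
                            + "\"attrs\":\\s*\\{[^}]*?\"shape\":\\s*\"\\()0(\\s*,)");

    private PatchBeginState() {}

    /** Rewrites every begin_state shape "(0, N)" to "(-1, N)". */
    public static String patch(String symbolJson) {
        return BEGIN_STATE_SHAPE.matcher(symbolJson).replaceAll("$1-1$2");
    }

    /** Counts the begin_state shapes that still have a 0 batch dimension. */
    public static int countZeroBatch(String symbolJson) {
        Matcher m = BEGIN_STATE_SHAPE.matcher(symbolJson);
        int count = 0;
        while (m.find()) {
            count++;
        }
        return count;
    }

    public static void main(String[] args) throws IOException {
        // Hypothetical default path, mirroring optModelPath/optModelName in the issue code
        Path file =
                Paths.get(args.length > 0 ? args[0] : "model/DeepAR/prediction_net-symbol.json");
        String original = new String(Files.readAllBytes(file), StandardCharsets.UTF_8);
        int found = countZeroBatch(original);

        // Keep a backup next to the original before overwriting it
        Path backup = file.resolveSibling(file.getFileName() + ".bak");
        Files.write(backup, original.getBytes(StandardCharsets.UTF_8));
        Files.write(file, patch(original).getBytes(StandardCharsets.UTF_8));
        System.out.println("Patched " + found + " begin_state node(s) in " + file);
    }
}
```

Running it a second time is harmless: patched shapes start with "-1", so the pattern no longer matches them.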

You can also download the already-modified model directly from our server. With the latest PR (https://github.com/deepjavalibrary/djl/pull/2027), this modified model is downloaded automatically.

YRChen1998 commented 2 years ago

It works. Thanks a lot; I look forward to more features in the Gluon-TS extension.