awslabs / gluonts

Probabilistic time series modeling in Python
https://ts.gluon.ai
Apache License 2.0

Rework `.predict`. #2299

Open jaheba opened 2 years ago

jaheba commented 2 years ago

Currently, `.predict` takes in a dataset and yields predictions for each input time series.

However, this interface is (a) not intuitive and (b) does not reflect how networks are invoked efficiently, namely in batches.

To predict a single time series, one currently has to do something like:

time_series = ...

predictions = predictor.predict([time_series])
prediction = list(predictions)[0]

Ideally, one could just do:

prediction = predictor.predict(time_series)

In addition, we should have `predict_batch`, which takes in a batch of input time series to predict at once. We might want to check that the size of the passed batch matches the model's `batch_size`.

The current behaviour could be provided via a `.predict_all(dataset)` or `.predict_dataset(dataset)` method.
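The method surface proposed above could be sketched roughly as follows. This is only an illustration of the shapes being discussed, not gluonts code: the `Predictor` internals and the placeholder forecast dicts are made up, and the batch-size check reflects the suggestion above.

```python
from typing import Iterable, Iterator, List


class Predictor:
    """Hypothetical sketch of the proposed prediction API."""

    batch_size: int = 32

    def predict_one(self, entry: dict) -> dict:
        # Single time series in, single forecast out: no wrapping
        # in a one-element list on the caller's side.
        return self.predict_batch([entry])[0]

    def predict_batch(self, batch: List[dict]) -> List[dict]:
        # One network invocation per batch; reject oversized batches,
        # as suggested in the comment above.
        if len(batch) > self.batch_size:
            raise ValueError(
                f"batch of size {len(batch)} exceeds model batch_size "
                f"{self.batch_size}"
            )
        # Placeholder for the actual network call.
        return [{"forecast_for": entry["item_id"]} for entry in batch]

    def predict_all(self, dataset: Iterable[dict]) -> Iterator[dict]:
        # The current `.predict` behaviour: one forecast per dataset
        # entry, computed batch by batch internally.
        buffer: List[dict] = []
        for entry in dataset:
            buffer.append(entry)
            if len(buffer) == self.batch_size:
                yield from self.predict_batch(buffer)
                buffer = []
        if buffer:
            yield from self.predict_batch(buffer)
```

With this shape, the single-series case becomes `predictor.predict_one(time_series)`, as wished for above.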

jaheba commented 2 years ago

One problem that I see is that it's very hard to gracefully deprecate the current .predict API.

Instead, we could have `predict_one`, `predict_batch`, `predict_all` and then think about how we would want to do the migration.

lostella commented 2 years ago

For batches, I was really thinking about only having the output be a stream of "batch forecasts" (see #2286), while keeping the input a Dataset, just like `.predict`. The reason is that it's convenient to avoid "deconstructing" the output of a network into individual Forecast objects: evaluation can probably run much faster if you don't. On the input side, the two possibilities are really a single data entry or a collection of them, which is a Dataset regardless of its size.
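The shape described here (Dataset in, stream of batch forecasts out) could look roughly like the sketch below. The `BatchForecast` class and the batching helper are invented for illustration; the real "batch forecast" object is the subject of #2286.

```python
from dataclasses import dataclass
from itertools import islice
from typing import Iterable, Iterator, List


@dataclass
class BatchForecast:
    # Stand-in for the "batch forecast" object discussed in #2286:
    # one object holding the network output for an entire batch,
    # not yet split into individual Forecast objects.
    entries: List[dict]


def predict_batches(
    dataset: Iterable[dict], batch_size: int = 32
) -> Iterator[BatchForecast]:
    """Yield one BatchForecast per batch of the input Dataset."""
    it = iter(dataset)
    while batch := list(islice(it, batch_size)):
        # A real implementation would invoke the network once here.
        yield BatchForecast(entries=batch)
```

The caller decides when (or whether) to break a `BatchForecast` apart, which is where the evaluation speed-up would come from.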

jaheba commented 2 years ago

I guess we can have both, predict_batch and predict_batches.

For example, we could use something like this:

def predict_batch(self, batch):
    # Run the prediction network once on the whole batch,
    # within the predictor's MXNet context.
    with mx.Context(self.ctx):
        return self.prediction_net.forecast(batch)