Open jaheba opened 2 years ago
One problem that I see is that it's very hard to gracefully deprecate the current `.predict` API. Instead we could have `predict_one`, `predict_batch`, `predict_all` and then think about how we would want to do the migration.
For batches, I was really thinking about only having the output be a stream of "batch forecasts" (see #2286), but keeping the input a `Dataset`, just like `.predict`. The reason for this is that it's convenient to avoid "deconstructing" the output of a network into individual `Forecast` objects, since you can probably evaluate much faster if you don't; on the input side, the two possibilities are really a single data entry or a collection of them, which is a `Dataset` regardless of its size.

I guess we can have both, `predict_batch` and `predict_batches`.
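A sketch of how the two could relate, using a plain-Python stand-in for the predictor (the class, method names, and placeholder "forecasts" here are illustrative, not the actual GluonTS implementation):

```python
class Predictor:
    def predict_batch(self, batch):
        # one batch of data entries in, one "batch forecast" out
        return [sum(entry["target"]) for entry in batch]  # placeholder forecast

    def predict_batches(self, batches):
        # a stream of batches in, a stream of batch forecasts out
        for batch in batches:
            yield self.predict_batch(batch)


predictor = Predictor()
batches = [[{"target": [1.0, 2.0]}], [{"target": [3.0]}]]
out = list(predictor.predict_batches(batches))
```

Under this split, `predict_batches` is a thin wrapper that maps `predict_batch` over a stream of batches, so the batching logic lives in one place.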
For example, we could use something like this:

```python
def predict_batch(self, batch):
    with mx.Context(self.ctx):
        return self.prediction_net.forecast(batch)
```
Currently, `.predict` takes in a dataset and yields predictions for each input time series. However, this interface is a) not intuitive and b) does not reflect efficient invocation of networks in batches.
To predict a single time-series, one has to do something like:
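Roughly the following, sketched with a minimal stand-in predictor (the real `Predictor` is more involved, but the calling pattern is the same):

```python
# Minimal stand-in for a predictor whose predict() expects a Dataset
# (an iterable of data entries) and yields one forecast per entry.
class Predictor:
    def predict(self, dataset):
        for entry in dataset:
            yield sum(entry["target"])  # placeholder forecast


predictor = Predictor()
entry = {"target": [1.0, 2.0, 3.0]}

# To forecast a single series, one must wrap it in a one-element list
# and then unwrap the single result from the returned iterator:
forecast = next(iter(predictor.predict([entry])))
```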
Ideally, one could just do:
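Something like the following, assuming a hypothetical `predict_one` method (one entry in, one forecast out; the stand-in below is illustrative only):

```python
class Predictor:
    def predict_one(self, entry):
        # hypothetical convenience method: single entry in, single forecast out
        return sum(entry["target"])  # placeholder forecast


predictor = Predictor()
forecast = predictor.predict_one({"target": [1.0, 2.0, 3.0]})
```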
In addition, we should have `predict_batch`, which takes in a batch of input time series to predict at once. We might want to check that the `batch_size` of the model corresponds with the size of the passed batch. The current behaviour could be provided via a `.predict_all(dataset)` or `.predict_dataset(dataset)` method.
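One way these pieces could fit together, sketched with the hypothetical names from above (`predict_batch`, `predict_all`) and a plain-Python stand-in for the network:

```python
class Predictor:
    def __init__(self, batch_size=32):
        self.batch_size = batch_size

    def predict_batch(self, batch):
        # reject batches larger than the model's configured batch size
        if len(batch) > self.batch_size:
            raise ValueError(
                f"batch of {len(batch)} exceeds batch_size {self.batch_size}"
            )
        return [sum(entry["target"]) for entry in batch]  # placeholder forecasts

    def predict_all(self, dataset):
        # current .predict behaviour: walk the whole dataset,
        # invoking the network in batch_size-sized chunks
        batch = []
        for entry in dataset:
            batch.append(entry)
            if len(batch) == self.batch_size:
                yield from self.predict_batch(batch)
                batch = []
        if batch:  # flush the final partial batch
            yield from self.predict_batch(batch)


predictor = Predictor(batch_size=2)
dataset = [{"target": [1.0]}, {"target": [2.0]}, {"target": [3.0]}]
forecasts = list(predictor.predict_all(dataset))
```

This keeps `predict_batch` as the single place where the network is invoked, with `predict_all` reduced to batching plus delegation.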