awslabs / gluonts

Probabilistic time series modeling in Python
https://ts.gluon.ai
Apache License 2.0

Factoring loss out of distribution objects #1482

Open · lostella opened 3 years ago

lostella commented 3 years ago

Problem: It is not clear how one can select a different loss to train a model, other than by customizing the network's code.

(Some proposal/discussion on this problem is also in #1043.)

Models (here meaning networks) in gluonts can in principle have all sorts of outputs: sample paths, marginal samples, parametric distributions (meaning distribution objects, or the tuple of tensors containing their parameters), point forecasts, quantiles... None of these is intrinsically tied to a "loss", even though, e.g., by default we compute the negative log-likelihood. So it might make sense to isolate the concept of "loss" as "something that compares (some form of) prediction to ground truth":

```python
class Loss:
    def __call__(self, prediction, ground_truth):
        # compare (some form of) prediction to ground truth
        raise NotImplementedError
```

(Type annotations are omitted on purpose.) For example, for a gluonts Distribution, some losses could be:

```python
import mxnet as mx
from gluonts.mx.distribution import Distribution

class NegativeLogLikelihood(Loss):
    def __call__(self, prediction: Distribution, ground_truth: mx.nd.NDArray):
        # one can also do any (weighted) averaging across the time dimension here
        return -prediction.log_prob(ground_truth)

class ContinuousRankedProbabilityScore(Loss):
    def __call__(self, prediction: Distribution, ground_truth: mx.nd.NDArray):
        # one can also do any (weighted) averaging across the time dimension here
        return prediction.crps(ground_truth)
```

Defining more would be straightforward. These objects could be injected into whichever model-training component does the loss computation (estimator, network, or trainer), and one could swap in a custom Loss object of the same nature as the default one (i.e., one that accepts the same type for the prediction argument).
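To make the injection idea concrete, here is a minimal, framework-free sketch. It assumes a toy `Gaussian` prediction class (hypothetical, standing in for a gluonts `Distribution`) that exposes `log_prob` and `crps`, and a hypothetical `ToyTrainingStep` component that receives the loss as a constructor argument. None of these names are gluonts APIs; the CRPS uses the standard closed form for a Gaussian.

```python
import math
from typing import Callable, Sequence


class Loss:
    def __call__(self, prediction, ground_truth):
        raise NotImplementedError


class Gaussian:
    """Hypothetical stand-in for a distribution object: one independent
    Gaussian per time step."""

    def __init__(self, mu: Sequence[float], sigma: Sequence[float]):
        self.mu = list(mu)
        self.sigma = list(sigma)

    def log_prob(self, y: Sequence[float]) -> list:
        return [
            -0.5 * math.log(2 * math.pi * s**2) - (v - m) ** 2 / (2 * s**2)
            for m, s, v in zip(self.mu, self.sigma, y)
        ]

    def crps(self, y: Sequence[float]) -> list:
        # Closed-form CRPS of a Gaussian, evaluated per time step.
        out = []
        for m, s, v in zip(self.mu, self.sigma, y):
            z = (v - m) / s
            pdf = math.exp(-0.5 * z * z) / math.sqrt(2 * math.pi)
            cdf = 0.5 * (1 + math.erf(z / math.sqrt(2)))
            out.append(s * (z * (2 * cdf - 1) + 2 * pdf - 1 / math.sqrt(math.pi)))
        return out


class NegativeLogLikelihood(Loss):
    def __call__(self, prediction, ground_truth):
        # Unweighted average over the time dimension.
        values = prediction.log_prob(ground_truth)
        return -sum(values) / len(values)


class ContinuousRankedProbabilityScore(Loss):
    def __call__(self, prediction, ground_truth):
        # Unweighted average over the time dimension.
        values = prediction.crps(ground_truth)
        return sum(values) / len(values)


class ToyTrainingStep:
    """Hypothetical training component: the loss is injected, not hard-coded."""

    def __init__(self, loss: Callable = NegativeLogLikelihood()):
        self.loss = loss

    def __call__(self, prediction, ground_truth) -> float:
        return self.loss(prediction, ground_truth)


pred = Gaussian(mu=[0.0, 1.0], sigma=[1.0, 1.0])
y = [0.0, 1.0]
nll = ToyTrainingStep()(pred, y)  # default loss
crps = ToyTrainingStep(loss=ContinuousRankedProbabilityScore())(pred, y)
```

Because the training step only calls `self.loss(prediction, ground_truth)`, any callable with that signature (a `Loss` subclass or a plain function) can be swapped in without touching the network code.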

jaheba commented 3 years ago

Couldn't we get away with just using functions? Or when would you initialise a loss with some custom parameters?

jaheba commented 3 years ago

Also, I guess the losses would be framework-specific, i.e., we would need different implementations for mxnet and torch.

lostella commented 3 years ago

> Couldn't we get away with just using functions? Or when would you initialise a loss with some custom parameters?

Of course, one should be able to use just a function for that. When the loss has parameters, one can write a class as above or fix the parameters with functools.partial. To some extent it is a matter of style, but if you need to serialize the loss, a class with a @validated constructor might be handy.
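As an illustration of the two styles, here is a sketch using a quantile (pinball) loss, which has a natural parameter, the quantile level `q`. The names `quantile_loss` and `QuantileLoss` are hypothetical, and the prediction is assumed to be a plain sequence of point forecasts for level `q` rather than a gluonts object.

```python
from functools import partial
from typing import Sequence


def quantile_loss(
    prediction: Sequence[float], ground_truth: Sequence[float], q: float
) -> float:
    """Pinball loss of quantile forecasts at level q, averaged over time."""
    total = 0.0
    for p, y in zip(prediction, ground_truth):
        diff = y - p
        # Under-prediction is weighted by q, over-prediction by (1 - q).
        total += max(q * diff, (q - 1) * diff)
    return total / len(ground_truth)


# Style 1: fix the parameter of a plain function with functools.partial.
p90_loss = partial(quantile_loss, q=0.9)


# Style 2: a small class that stores the parameter as an explicit attribute,
# which makes its configuration easy to inspect and serialize.
class QuantileLoss:
    def __init__(self, q: float):
        self.q = q

    def __call__(self, prediction, ground_truth) -> float:
        return quantile_loss(prediction, ground_truth, q=self.q)
```

Both objects can be called as `loss(prediction, ground_truth)`, so either fits the `Loss` interface. In gluonts, the class form is where a `@validated()` constructor could be added for serialization; it is left out here to keep the sketch dependency-free.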

jaheba commented 3 years ago

I think we really want to stop using validated everywhere. In many cases, a simple pydantic model should be enough.
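For comparison, the same parametrized loss could be written as a plain pydantic model. This is a sketch assuming pydantic is installed (gluonts itself depends on it); `QuantileLoss` is a hypothetical name, and the code sticks to features shared by pydantic v1 and v2 (`BaseModel`, `Field` constraints, `dict(model)`).

```python
from pydantic import BaseModel, Field


class QuantileLoss(BaseModel):
    """Hypothetical loss whose parameters are validated pydantic fields."""

    # Quantile level, validated to lie strictly between 0 and 1.
    q: float = Field(0.5, gt=0.0, lt=1.0)

    def __call__(self, prediction, ground_truth) -> float:
        # Pinball loss averaged over the time dimension.
        total = 0.0
        for p, y in zip(prediction, ground_truth):
            diff = y - p
            total += max(self.q * diff, (self.q - 1) * diff)
        return total / len(ground_truth)


loss = QuantileLoss(q=0.9)
config = dict(loss)               # plain dict of field values: {"q": 0.9}
restored = QuantileLoss(**config)  # rebuild the loss from its configuration
```

Parameter validation (e.g., rejecting `q=1.5`) and a serializable configuration come from the model itself, without a custom decorator on the constructor.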

lostella commented 3 years ago

Yes, whatever works

baharian commented 7 months ago

@lostella Hi! Apologies for reviving a zombie thread, but is there any update on this issue?