jdb78 / pytorch-forecasting

Time series forecasting with PyTorch
https://pytorch-forecasting.readthedocs.io/
MIT License

Combined loss function using multiple targets #1098

Open serhan-gul opened 1 year ago

serhan-gul commented 1 year ago

I have a multi-target regression problem where I currently use MAE loss for each of the targets.

model = LitFCMultiTargetModel.from_dataset(
    train_dataset,
    hidden_size=self.hidden_size,
    n_hidden_layers=self.n_hidden_layers,
    log_interval=1,
    loss=MultiLoss(metrics=[MAE()] * len(self.targets)),
)

However, the targets are geometrically related to each other: they are x, y, z coordinates measured by a sensor. I therefore want to use the Euclidean distance between the predicted and true points as the loss, instead of measuring the loss for each coordinate separately with MAE. Is there a way to create such a combined loss function?
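For concreteness, the quantity I want to minimize is just the per-timestep Euclidean distance. A minimal sketch in plain PyTorch, assuming the per-coordinate predictions and targets were already stacked into (batch, time, 3) tensors (getting the framework to hand over the targets in this combined form is exactly what I am unsure about):

import torch

def euclidean_loss(y_pred: torch.Tensor, target: torch.Tensor) -> torch.Tensor:
    # y_pred, target: (batch, time, 3) with the x, y, z targets stacked on the last axis
    # returns the per-sample, per-timestep Euclidean distance, shape (batch, time)
    return torch.linalg.norm(y_pred - target, dim=-1)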

zhengkd95 commented 1 year ago

It is really difficult to implement a combined loss in the current framework if your train_dataset is defined with multiple targets: the models in this package automatically wrap the original loss function in a MultiLoss, which scores each target separately. I encountered the same problem. Supporting a single model that forecasts multi-target, multi-horizon time series with a joint loss is not a simple problem, as changes are required at a very basic level. Could you please look into this issue if you have time @jdb78?
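To illustrate what I mean: with multiple targets you always end up with something equivalent to the code below, where each metric only ever sees its own target, so there is no place to put a cross-target term.

from pytorch_forecasting.metrics import MAE, MultiLoss

# one metric per target; losses are computed and aggregated per target,
# so no term can couple the targets
loss = MultiLoss(metrics=[MAE(), MAE(), MAE()])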

I tried to implement a MultiTargetMultiHorizonMetric. However, this package is sophisticated enough that I suspect my code only works for my own setup.

aeklant commented 1 year ago

I am running into a similar difficulty; in my case it stems from the business application of the time series predictions.

I have two targets that are related through business logic which determines the bottom line (profit). If I train with a separate loss function for each, I can end up with a model that has poor loss metrics yet severely outperforms a model with better metrics in terms of the business application, i.e. profitability. This could be solved easily if I were able to encode that business logic into the loss function itself.
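As a toy sketch of the idea (the pairing of the two targets into a profit calculation below is hypothetical, purely to illustrate the shape of it):

import torch

def profit_aware_loss(y_pred: torch.Tensor, target: torch.Tensor) -> torch.Tensor:
    # y_pred, target: (batch, time, 2); [..., 0] = units sold, [..., 1] = unit price
    # (a hypothetical pairing of targets, purely for illustration)
    predicted_profit = y_pred[..., 0] * y_pred[..., 1]
    actual_profit = target[..., 0] * target[..., 1]
    # penalize error in the derived business quantity, not per-target error
    return (predicted_profit - actual_profit).abs()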

I assume supporting this use case would greatly benefit the adoption of pytorch_forecasting in business settings.

@zhengkd95 would you mind sharing your general implementation? I am somewhat new to pytorch_forecasting, so any help would be appreciated.

@jdb78 do you have any pointers on how to implement this correctly? I wouldn't mind taking a crack at it, but I am not sure what the general solution should look like. So far I can see that the model's step function has access to the y values as a list, so I am thinking of concatenating them into a single tensor and creating a new metric (which I think is what @zhengkd95 described), but I can't tell whether this would lead to undesired side effects down the line. What are your thoughts? A rough sketch of what I mean follows.
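Assuming point predictions, so that y_pred and target each arrive as a list of (batch, time) tensors, one per target, the core of it would be something like:

import torch

# stack the per-target tensors on a new last axis: (batch, time, n_targets);
# any joint loss, e.g. a Euclidean distance across targets, can then be computed
y_pred_stack = torch.stack(y_pred, dim=-1)
target_stack = torch.stack(target, dim=-1)
joint_loss = torch.linalg.norm(y_pred_stack - target_stack, dim=-1)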

zhengkd95 commented 1 year ago

@aeklant In the end I gave up on the implementation. I am also new to this package, and it is way too sophisticated for me to crack the models open. But I can show you part of the loss function:

import warnings
from typing import List

import torch
from sklearn.base import BaseEstimator
from torch.nn.utils import rnn

from pytorch_forecasting.metrics import MultiHorizonMetric
from pytorch_forecasting.utils import unpack_sequence, unsqueeze_like


class MultiTargetMultiHorizonMetric(MultiHorizonMetric):
    def update(self, y_pred, target):
        """
        Update method of metric that handles masking of values.

        Do not override this method but :py:meth:`~loss` instead

        Args:
            y_pred (Dict[str, torch.Tensor]): network output
            target (Union[torch.Tensor, rnn.PackedSequence]): actual values

        Returns:
            torch.Tensor: loss as a single number for backpropagation
        """
        # unpack weight
        if isinstance(target, (list, tuple)) and not isinstance(target, rnn.PackedSequence):
            target, weight = target
        else:
            weight = None

        # unpack target
        if isinstance(target, rnn.PackedSequence):
            target, lengths = unpack_sequence(target)
        elif isinstance(target, list):
            # stack the per-target tensors so the loss sees all targets at once
            target_stack = torch.stack(target)
            lengths = torch.full(
                (target_stack.size(0),),
                fill_value=target_stack.size(1) * target_stack.size(2),
                dtype=torch.long,
                device=target_stack.device,
            )
        else:
            # lengths: (batch_size,)
            lengths = torch.full(
                (target.size(0),), fill_value=target.size(1), dtype=torch.long, device=target.device
            )

        losses = self.loss(y_pred, target)
        # weight samples
        if weight is not None:
            losses = losses * unsqueeze_like(weight, losses)
        self._update_losses_and_lengths(losses, lengths)

    def to_prediction(self, y_pred: torch.Tensor) -> torch.Tensor:
        """
        Convert network prediction into a point prediction.

        Args:
            y_pred: prediction output of network

        Returns:
            torch.Tensor: point prediction
        """
        if isinstance(y_pred, list):
            # predictions for multiple targets arrive as a list: stack them into one tensor
            return torch.stack(y_pred)
        if y_pred.ndim == 3:
            if self.quantiles is None:
                assert y_pred.size(-1) == 1, "Prediction should only have one extra dimension"
                y_pred = y_pred[..., 0]
            else:
                y_pred = y_pred.mean(-1)
        return y_pred

    def mask_losses(self, losses: torch.Tensor, lengths: torch.Tensor, reduction: str = None) -> torch.Tensor:
        """
        Mask losses.

        Args:
            losses (torch.Tensor): total loss. first dimension is samples, second is timesteps
            lengths (torch.Tensor): total length
            reduction (str, optional): type of reduction. Defaults to ``self.reduction``.

        Returns:
            torch.Tensor: masked losses
        """
        if reduction is None:
            reduction = self.reduction
        if losses.ndim >= 3:
            # stacked multi-target losses: do not mask at all
            return losses
        if losses.ndim > 0:
            # mask loss
            mask = torch.arange(losses.size(1), device=losses.device).unsqueeze(0) >= lengths.unsqueeze(-1)
            if losses.ndim > 2:
                mask = mask.unsqueeze(-1)
                dim_normalizer = losses.size(-1)
            else:
                dim_normalizer = 1.0
            # reduce to one number
            if reduction == "none":
                losses = losses.masked_fill(mask, float("nan"))
            else:
                losses = losses.masked_fill(mask, 0.0) / dim_normalizer
        return losses

    def to_quantiles(self, y_pred: torch.Tensor, quantiles: List[float] = None) -> torch.Tensor:
        """
        Convert network prediction into a quantile prediction.

        Args:
            y_pred: prediction output of network
            quantiles (List[float], optional): quantiles for probability range. Defaults to quantiles as
                as defined in the class initialization.

        Returns:
            torch.Tensor: prediction quantiles
        """
        if quantiles is None:
            quantiles = self.quantiles
        if isinstance(y_pred, list):
            # convert each target's prediction separately
            results = []
            for y_pred_i in y_pred:
                results.append(self.to_quantiles(y_pred_i, quantiles))
            return results
        else:
            if y_pred.ndim == 2:
                return y_pred.unsqueeze(-1)
            elif y_pred.ndim == 3:
                if y_pred.size(2) > 1:  # single dimension means all quantiles are the same
                    assert quantiles is not None, "quantiles are not defined"
                    y_pred = torch.quantile(y_pred, torch.tensor(quantiles, device=y_pred.device), dim=2).permute(1, 2, 0)
                return y_pred
            else:
                raise ValueError(f"prediction has 1 or more than 3 dimensions: {y_pred.ndim}")

    def _update_losses_and_lengths(self, losses: torch.Tensor, lengths: torch.Tensor):
        losses = self.mask_losses(losses, lengths)
        if self.reduction == "none":
            if isinstance(losses, list):
                for i in range(len(losses)):
                    if self.losses[i].ndim == 0:
                        self.losses[i] = losses[i]
                        self.lengths[i] = lengths[i]
                    else:
                        self.losses[i] = torch.cat([self.losses[i], losses[i]], dim=0)
                        self.lengths[i] = torch.cat([self.lengths[i], lengths[i]], dim=0)
            else:
                if self.losses.ndim == 0:
                    self.losses = losses
                    self.lengths = lengths
                else:
                    self.losses = torch.cat([self.losses, losses], dim=0)
                    self.lengths = torch.cat([self.lengths, lengths], dim=0)
        else:
            if isinstance(losses, list):
                for i in range(len(losses)):
                    losses[i] = losses[i].sum()
                    if not torch.isfinite(losses[i]):
                        losses[i] = torch.tensor(1e9, device=losses[i].device)
                        warnings.warn("Loss is not finite. Resetting it to 1e9")
                    self.losses[i] = self.losses[i] + losses[i]
                    self.lengths[i] = self.lengths[i] + lengths[i].sum()
            else:
                losses = losses.sum()
                if not torch.isfinite(losses):
                    losses = torch.tensor(1e9, device=losses.device)
                    warnings.warn("Loss is not finite. Resetting it to 1e9")
                self.losses = self.losses + losses
                self.lengths = self.lengths + lengths.sum()

    def rescale_parameters(
        self, parameters: torch.Tensor, target_scale: torch.Tensor, encoder: BaseEstimator
    ) -> torch.Tensor:
        """
        Rescale normalized parameters into the scale required for the output.

        Args:
            parameters (torch.Tensor): normalized parameters (indexed by last dimension)
            target_scale (torch.Tensor): scale of parameters (n_batch_samples x (center, scale))
            encoder (BaseEstimator): original encoder that normalized the target in the first place

        Returns:
            torch.Tensor: parameters in real/not normalized space
        """
        # rescale each target's predictions with its own encoder
        result = [
            encoder[i](dict(prediction=parameters[:, :, i], target_scale=target_scale[i]))
            for i in range(len(encoder))
        ]
        return result


class my_MAE(MultiTargetMultiHorizonMetric):
    def loss(self, y_pred, target):
        # stack the per-target lists into single tensors so the absolute error
        # is computed jointly across all targets
        if isinstance(target, list) or isinstance(y_pred, list):
            try:
                target_stack = torch.stack(target).squeeze()
            except TypeError:
                # already a tensor
                target_stack = target
            try:
                y_pred_stack = torch.stack(y_pred).squeeze()
            except TypeError:
                # already a tensor
                y_pred_stack = y_pred
            loss = (y_pred_stack - target_stack).abs()
        else:
            loss = (self.to_prediction(y_pred) - target).abs()
        return loss

Here the my_MAE class calculates a combined loss over the different targets.
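If you want to experiment with it, the intended wiring is roughly the following, reusing LitFCMultiTargetModel and train_dataset from the original post. Note that, as mentioned above, from_dataset may still wrap the loss in a MultiLoss when the dataset has multiple targets; that is the part I could not fully crack.

model = LitFCMultiTargetModel.from_dataset(
    train_dataset,
    loss=my_MAE(),  # one joint metric instead of MultiLoss(metrics=[MAE()] * n_targets)
)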