Lightning-AI / torchmetrics

Torchmetrics - Machine learning metrics for distributed, scalable PyTorch applications.
https://lightning.ai/docs/torchmetrics/
Apache License 2.0
2.01k stars 391 forks source link

Add `WeightedAbsolutePercentageError` #928

Closed Guan-t7 closed 2 years ago

Guan-t7 commented 2 years ago

🚀 Feature

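Add a Weighted Absolute Percentage Error (WAPE) metric, defined as $\mathrm{WAPE} = \frac{\sum_t |y_t - \hat{y}_t|}{\sum_t |y_t|}$, where $y$ are the targets and $\hat{y}$ the predictions. A full description can be found here: https://en.wikipedia.org/wiki/WMAPE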

Motivation

WAPE is a common metric used in time series forecasting.

Alternatives

I don't think this metric can be implemented through arithmetic on existing metrics.

Additional context

A draft implementation follows; I'd like someone to take it over and finish the rest.

'''functional'''
from typing import Tuple

import torch
from torch import Tensor

from torchmetrics.utilities.checks import _check_same_shape

def _weighted_absolute_percentage_error_update(
    preds: Tensor,
    target: Tensor,
) -> Tuple[Tensor, Tensor]:
    """Updates and returns variables required to compute Weighted Absolute Percentage Error.

    The shapes of ``preds`` and ``target`` are validated with ``_check_same_shape`` before any
    arithmetic is done.

    Args:
        preds: Predicted tensor
        target: Ground truth tensor

    Returns:
        A tuple ``(sum_abs_error, sum_scale)`` of scalar tensors, where ``sum_abs_error`` is the sum of
        ``|preds - target|`` and ``sum_scale`` is the sum of ``|target|`` over all elements.
    """
    _check_same_shape(preds, target)

    # Both quantities are plain sums, so they can be accumulated across batches (and summed
    # across processes) and only divided once, at compute time.
    sum_abs_error = (preds - target).abs().sum()
    sum_scale = target.abs().sum()

    return sum_abs_error, sum_scale

def _weighted_absolute_percentage_error_compute(
    sum_abs_error: Tensor,
    sum_scale: Tensor,
    epsilon: float = 1.17e-06,
) -> Tensor:
    """Computes Weighted Absolute Percentage Error from accumulated sums.

    Args:
        sum_abs_error: Sum of absolute errors ``|preds - target|`` over all observations
        sum_scale: Sum of absolute target values ``|target|`` over all observations
        epsilon: Lower bound applied to ``sum_scale`` so that an all-zero target does not cause a
            division by zero.

    Returns:
        Tensor equal to ``sum_abs_error / max(sum_scale, epsilon)``.
    """
    # sum_scale is a sum of absolute values and therefore non-negative; clamping from below only
    # guards the degenerate all-zero-target case.
    return sum_abs_error / torch.clamp(sum_scale, min=epsilon)

def weighted_absolute_percentage_error(preds: Tensor, target: Tensor) -> Tensor:
    r"""Computes weighted absolute percentage error (WAPE_).

    .. math::
        \text{WAPE} = \frac{\sum_{t=1}^n | y_t - \hat{y}_t | }{\sum_{t=1}^n |y_t| }

    Where :math:`y` is a tensor of target values, and :math:`\hat{y}` is a tensor of predictions.
    The denominator is clamped from below to a small epsilon (``1.17e-06``) so that an all-zero
    target does not cause a division by zero.

    Args:
        preds: estimated labels
        target: ground truth labels (must have the same shape as ``preds``)

    Return:
        Tensor with WAPE.
    """
    sum_abs_error, sum_scale = _weighted_absolute_percentage_error_update(preds, target)
    return _weighted_absolute_percentage_error_compute(sum_abs_error, sum_scale)

'''module'''
from typing import Any, Callable, Optional

import torch
from torch import Tensor, tensor

# todo functional import
from torchmetrics.metric import Metric

class WeightedAbsolutePercentageError(Metric):
    r"""
    Computes weighted absolute percentage error (`WAPE`_).

    .. math::
        \text{WAPE} = \frac{\sum_{t=1}^n | y_t - \hat{y}_t | }{\sum_{t=1}^n |y_t| }

    Where :math:`y` is a tensor of target values, and :math:`\hat{y}` is a tensor of predictions.

    The metric keeps two running sums (total absolute error and total absolute target) and only
    divides them in ``compute()``. Both states are reduced with ``sum`` across processes, so the
    synchronized result equals WAPE over all data seen by every process.

    Args:
        compute_on_step:
            Forward only calls ``update()`` and return None if this is set to False.
        dist_sync_on_step:
            Synchronize metric state across processes at each ``forward()`` before returning the value at the step.
        process_group:
            Specify the process group on which synchronization is called.
        dist_sync_fn:
            Forwarded unchanged to ``Metric.__init__`` (presumably a custom function used to gather
            state across processes; ``None`` keeps the base-class default).

    Note:
        WAPE output is a non-negative floating point. Best result is 0.0 .
    """
    is_differentiable = True
    higher_is_better = False
    sum_abs_error: Tensor
    sum_scale: Tensor

    def __init__(
        self,
        compute_on_step: bool = True,
        dist_sync_on_step: bool = False,
        process_group: Optional[Any] = None,
        dist_sync_fn: Optional[Callable] = None,
    ) -> None:
        super().__init__(
            compute_on_step=compute_on_step,
            dist_sync_on_step=dist_sync_on_step,
            process_group=process_group,
            dist_sync_fn=dist_sync_fn,
        )

        # Running sums; "sum" reduction is correct because WAPE is a ratio of two additive totals.
        self.add_state("sum_abs_error", default=tensor(0.0), dist_reduce_fx="sum")
        self.add_state("sum_scale", default=tensor(0.0), dist_reduce_fx="sum")

    def update(self, preds: Tensor, target: Tensor) -> None:  # type: ignore
        """Update state with predictions and targets.

        Args:
            preds: Predictions from model
            target: Ground truth values (must have the same shape as ``preds``)
        """
        sum_abs_error, sum_scale = _weighted_absolute_percentage_error_update(preds, target)

        self.sum_abs_error += sum_abs_error
        self.sum_scale += sum_scale

    def compute(self) -> Tensor:
        """Computes weighted absolute percentage error over state.

        Returns:
            Accumulated ``sum_abs_error`` divided by accumulated ``sum_scale`` (with the denominator
            clamped from below by the compute helper's default epsilon).
        """
        return _weighted_absolute_percentage_error_compute(self.sum_abs_error, self.sum_scale)
github-actions[bot] commented 2 years ago

Hi! Thanks for your contribution, great first issue!