jdb78 / pytorch-forecasting

Time series forecasting with PyTorch
https://pytorch-forecasting.readthedocs.io/
MIT License

Troubleshooting loss function; implementing focal loss #624

Open jsolson4 opened 2 years ago

jsolson4 commented 2 years ago

Greetings,

Jan, thank you so much for the package. It's amazing. I am working on implementing focal loss. It's the first loss function I've designed, so I am running into a few issues.

How does one print components of a custom loss function (or other aspects of the package)? It's hard to know what's going on without the ability to print out and visualize objects as they're passed around. I've written both print() and premature return statements in an attempt to display the components of the loss function (such as target and actual). At one point, I was able to see the objects in my display (after calling fit).

Is calling fit() and embedding print statements a good way to assess the loss function components? I was going to use dummy data, but the input is a MultiHorizonMetric. I was unsure how to create a dummy MultiHorizonMetric to play around with, so I started calling fit(). I'd appreciate any advice about using MultiHorizonMetric itself.
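One way to inspect the pieces without calling fit() is to put the loss arithmetic in a plain function and call it on random tensors. The following is a minimal sketch, not the package's API: `focal_loss_debug` is a hypothetical helper, and the shapes (batch, horizon, classes) for the prediction and (batch, horizon) for the target are assumptions based on the error message further down.

```python
import torch
import torch.nn.functional as F


def focal_loss_debug(y_pred, target, alpha=0.5, gamma=2.0, verbose=True):
    """Hypothetical helper: focal loss on raw tensors, printing intermediates.

    Assumed shapes: y_pred is (batch, horizon, n_classes) logits,
    target is (batch, horizon) integer class labels.
    """
    n_classes = y_pred.size(-1)
    # cross_entropy wants (N, C) logits and (N,) labels, so merge batch and horizon
    ce = F.cross_entropy(
        y_pred.reshape(-1, n_classes), target.reshape(-1), reduction="none"
    ).reshape(target.shape)
    pt = torch.exp(-ce)  # probability the model assigns to the true class
    loss = alpha * (1 - pt) ** gamma * ce
    if verbose:
        for name, t in [("y_pred", y_pred), ("target", target),
                        ("ce", ce), ("pt", pt), ("loss", loss)]:
            print(f"{name}: shape={tuple(t.shape)} dtype={t.dtype}")
    return loss


# Dummy data: 1 series, 24-step horizon, 24 classes (all values are made up)
torch.manual_seed(0)
y_pred = torch.randn(1, 24, 24)
target = torch.randint(0, 24, (1, 24))
loss = focal_loss_debug(y_pred, target)
```

Once the function behaves as expected on dummy tensors, the same lines can be moved into the `loss` method of the metric class.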

So here is my current implementation:

```python
class FocalLoss(MultiHorizonMetric):
    def loss(self, y_pred, target, alpha = 0.5, gamma = 2):
        p = torch.sigmoid(y_pred)
        BCE_loss = F.binary_cross_entropy_with_logits(y_pred, target, reduction='none')
        pt = torch.exp(-BCE_loss) # prevents nans when probability 0
        F_loss = self.alpha * (1-pt)**self.gamma * BCE_loss
        return F_loss.mean()

    def to_prediction(self, y_pred: torch.Tensor) -> torch.Tensor:
        """Convert network prediction into a point prediction. Returns best label"""
        return y_pred.argmax(dim = -1)

    def to_quantiles(self, y_pred: torch.Tensor, quantiles: List[float] = None) -> torch.Tensor:
        """
        Convert network prediction into a quantile prediction.

        Args:
          y_pred: prediction output of network
          quantiles (List[float], optional): quantiles for probability range. Defaults to quantiles as defined in the class initialization.

        Returns:
          torch.Tensor: prediction quantiles
        """
        return y_pred
```

I am predicting on a horizon of 24 weeks. I receive an error:

ValueError: Target size (torch.Size([1, 24])) must be the same as input size (torch.Size([1, 24, 24]))

The error is triggered at this line in the loss function: BCE_loss = F.binary_cross_entropy_with_logits(y_pred, target, reduction='none')

Given this information, I am confused about what the three dimensions of the input size tensor represent. From what I understand, the first position is the batch size and the third position is the output size. If that is correct, what does the second dimension represent? Is there an error in my set-up that's causing me to produce a 3D tensor when I should only have a 2D tensor?
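The mismatch can be reproduced on dummy tensors, which helps separate a set-up problem from a choice-of-function problem. This sketch assumes the prediction is laid out as (batch, horizon, output_size) and the target as (batch, horizon) class indices. That reading fits the sizes in the error and a 24-week horizon, but it is an assumption here. Under it, the 3D prediction is expected, and `binary_cross_entropy_with_logits` is the part that does not fit, because it needs a target with the same shape as the input.

```python
import torch
import torch.nn.functional as F

y_pred = torch.randn(1, 24, 24)          # assumed: (batch, horizon, output_size) logits
target = torch.randint(0, 24, (1, 24))   # assumed: (batch, horizon) class indices

# binary_cross_entropy_with_logits requires input and target of identical shape,
# so a (1, 24) target against a (1, 24, 24) input raises ValueError
try:
    F.binary_cross_entropy_with_logits(y_pred, target.float(), reduction="none")
    raised = False
except ValueError as err:
    raised = True
    print(err)

# cross_entropy takes integer class indices instead; it expects the class
# dimension at position 1, so move it there: (batch, classes, horizon)
ce = F.cross_entropy(y_pred.permute(0, 2, 1), target, reduction="none")
print(ce.shape)  # one loss value per (batch, horizon) position
```

Flattening batch and horizon into one dimension before calling `cross_entropy` is an equivalent alternative to the `permute` used here.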

jsolson4 commented 2 years ago

Update: here is my implementation. Would appreciate any feedback, it's my first custom loss function :).


```python
from typing import List

import torch
import torch.nn.functional as torchF  # normally just 'F' but I already used that identifier
from pytorch_forecasting.metrics import MultiHorizonMetric


class FocalLoss(MultiHorizonMetric):

    def __init__(self, alpha, gamma):
        super().__init__()
        self.alpha = alpha
        self.gamma = gamma

    def loss(self, y_pred, target):
        # flatten batch and horizon so cross_entropy sees (N, C) logits and (N,) labels
        BCE_loss = torchF.cross_entropy(y_pred.view(-1, y_pred.size(-1)), target.view(-1), reduction="none").view(-1, target.size(-1))
        pt = torch.exp(-BCE_loss)  # prevents nans when probability 0
        F_loss = self.alpha * (1 - pt) ** self.gamma * BCE_loss
        return F_loss.mean()

    def to_prediction(self, y_pred: torch.Tensor) -> torch.Tensor:
        """Convert network prediction into a point prediction. Returns best label"""
        return y_pred.argmax(dim=-1)

    def to_quantiles(self, y_pred: torch.Tensor, quantiles: List[float] = None) -> torch.Tensor:
        """
        Convert network prediction into a quantile prediction.

        Args:
          y_pred: prediction output of network
          quantiles (List[float], optional): quantiles for probability range. Defaults to quantiles as defined in the class initialization.

        Returns:
          torch.Tensor: prediction quantiles
        """
        return y_pred
```
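A few property checks can give some confidence in the arithmetic before training. This is a minimal sketch on made-up tensors: `focal_from_logits` is a hypothetical standalone copy of the `loss` arithmetic above (without the final mean), and the shapes are assumed to be (batch, horizon, classes) and (batch, horizon).

```python
import torch
import torch.nn.functional as F


def focal_from_logits(y_pred, target, alpha, gamma):
    # Hypothetical copy of the loss() arithmetic, kept per element (no mean)
    ce = F.cross_entropy(
        y_pred.view(-1, y_pred.size(-1)), target.view(-1), reduction="none"
    ).view(-1, target.size(-1))
    pt = torch.exp(-ce)
    return alpha * (1 - pt) ** gamma * ce


torch.manual_seed(0)
y_pred = torch.randn(4, 24, 24)          # made-up logits
target = torch.randint(0, 24, (4, 24))   # made-up class labels

ce = F.cross_entropy(y_pred.view(-1, 24), target.view(-1), reduction="none").view(4, 24)

# 1) With alpha=1 and gamma=0 the focal term vanishes and plain cross-entropy remains
assert torch.allclose(focal_from_logits(y_pred, target, 1.0, 0.0), ce)

# 2) exp(-cross_entropy) is the softmax probability of the true class
p_true = torch.softmax(y_pred, dim=-1).gather(-1, target.unsqueeze(-1)).squeeze(-1)
assert torch.allclose(torch.exp(-ce), p_true, atol=1e-6)

# 3) With gamma > 0 and alpha <= 1 the focal loss never exceeds cross-entropy
focal = focal_from_logits(y_pred, target, 0.5, 2.0)
assert bool((focal <= ce).all())
```

Check 2 is why `pt` can be computed from the cross-entropy value instead of calling softmax separately.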