slaypni / fastdtw

A Python implementation of FastDTW
MIT License

Creating a Loss module in PyTorch with FastDTW #63

Open: rllyryan opened this issue 8 months ago


Hi repository owner(s)!

I greatly appreciate your work, and I strongly believe this implementation can help keep regression models from degenerating into persistence models (which simply repeat the last observed value) during training, so that they become genuinely useful for forecasting.

I need some help creating a loss module with your library. Here is my current (simple) design:

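import torch
import torch.nn as nn
from fastdtw import fastdtw
from scipy.spatial.distance import euclidean

class CustomLoss(nn.Module):
    def __init__(self):
        super().__init__()

    def forward(self, predicted, target):
        # Calculate the dynamic time warping distance using FastDTW.
        # fastdtw works on NumPy arrays, so the tensors are detached and copied to the CPU;
        # scipy's euclidean compares one time step at a time, so inputs are assumed to be (seq_len, n_features)
        distance, _ = fastdtw(predicted.detach().cpu().numpy(), target.detach().cpu().numpy(), dist=euclidean)

        # Convert the distance to a PyTorch tensor.
        # This builds a new leaf tensor, so backward() stops here and no gradient reaches `predicted`
        distance = torch.tensor(distance, dtype=torch.float32, requires_grad=True)

        # Return the distance as the loss
        return distance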

criterion = CustomLoss()

Is this the right way to integrate your library into a custom loss module?
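For comparison, here is a minimal sketch of a variant that keeps gradients flowing. fastdtw is used only to find the warping path (its second return value), and the cost along that path is then recomputed with PyTorch operations on `predicted`, so `backward()` reaches the model. The names `dtw_path_cost` and `FastDTWLoss` are hypothetical. Inputs are assumed to be single series of shape (seq_len, n_features). With `dist=euclidean`, the value should match the distance fastdtw reports, up to floating-point precision.

```python
import torch
import torch.nn as nn


def dtw_path_cost(predicted: torch.Tensor, target: torch.Tensor, path) -> torch.Tensor:
    """Hypothetical helper: sum of Euclidean distances along a warping path.

    `predicted` and `target` have shape (seq_len, n_features); `path` is a sequence of
    (i, j) index pairs, such as the path fastdtw returns. The cost is computed with
    PyTorch ops on `predicted` itself, so it stays in the autograd graph.
    """
    idx = torch.as_tensor(path, dtype=torch.long, device=predicted.device)
    diffs = predicted[idx[:, 0]] - target[idx[:, 1]]  # shape: (path_len, n_features)
    return torch.linalg.vector_norm(diffs, dim=-1).sum()


class FastDTWLoss(nn.Module):
    """Hypothetical loss: fastdtw chooses the alignment, PyTorch computes its cost."""

    def forward(self, predicted: torch.Tensor, target: torch.Tensor) -> torch.Tensor:
        # Imported here so dtw_path_cost can be used without fastdtw/scipy installed.
        from fastdtw import fastdtw
        from scipy.spatial.distance import euclidean

        if predicted.dim() == 1:  # treat a univariate series as (seq_len, 1)
            predicted, target = predicted.unsqueeze(-1), target.unsqueeze(-1)

        # The alignment search is not differentiable, so it runs on detached CPU copies.
        _, path = fastdtw(
            predicted.detach().cpu().numpy(),
            target.detach().cpu().numpy(),
            dist=euclidean,
        )
        # The path is held fixed; gradients flow through the cost along it.
        return dtw_path_cost(predicted, target, path)
```

Usage would mirror the original module: `criterion = FastDTWLoss()`, then `loss = criterion(output, target)` and `loss.backward()`. fastdtw aligns one pair of series at a time, so a batch of shape (batch, seq_len, n_features) would need a loop over the batch dimension, for example averaging the per-sample losses. Because the path is held fixed, the gradient is that of the cost along the currently chosen alignment. The path itself can still change from one training step to the next.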

Thank you!