salesforce / CoST

PyTorch code for CoST: Contrastive Learning of Disentangled Seasonal-Trend Representations for Time Series Forecasting (ICLR 2022)
BSD 3-Clause "New" or "Revised" License
212 stars 43 forks

Training Loss problem. #25

Open 740402059 opened 8 months ago

740402059 commented 8 months ago

When I used your algorithm and parameters to train on both the WTH dataset and my own dataset, I found that the loss was very low in the first epoch, but increased sharply in the second epoch, and subsequently, the loss remained higher than in the first epoch. The variation in the training loss is perplexing, and I hope you can provide some insights.

Wentao-Gao commented 7 months ago

I think the problem lies in the loss function. The Time Domain Contrastive Loss uses the MoCo loss, but with an extra term in the denominator:

[Image: Screenshot 2023-11-19 1:23:00 PM]

I assume that term is not necessary. If you delete it, you may solve the problem.
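For reference, here is a sketch of the two forms being compared. The notation is an assumption based on the standard MoCo InfoNCE objective (query $q$, positive key $k_+$, $K$ negative keys $k_j$, temperature $\tau$); it is not a transcription of the screenshot.

```latex
% Standard InfoNCE: the positive appears in both numerator and denominator
\mathcal{L}_{\text{InfoNCE}}
  = -\log \frac{\exp(q \cdot k_+ / \tau)}
               {\exp(q \cdot k_+ / \tau) + \sum_{j=1}^{K} \exp(q \cdot k_j / \tau)}

% Variant: the denominator sums over the negatives only
\mathcal{L}_{\text{neg-only}}
  = -\log \frac{\exp(q \cdot k_+ / \tau)}
               {\sum_{j=1}^{K} \exp(q \cdot k_j / \tau)}
```

The first form is bounded below by 0 because its denominator always contains the numerator. The second form has no such bound.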

The loss code I wrote for the CoSTModel class in CoST/cost.py is:


import torch


def compute_loss(self, q, k, k_negs):
    # compute logits
    # positive logits: Nx1
    l_pos = torch.einsum('nc,nc->n', [q, k]).unsqueeze(-1)
    # negative logits: NxK
    l_neg = torch.einsum('nc,ck->nk', [q, k_negs])

    # apply temperature
    l_pos = l_pos / self.T
    l_neg = l_neg / self.T

    # Denominator over the negatives only: the positive logit is left out.
    # (Subtracting a large value from the negative columns and calling
    # F.cross_entropy would do the opposite and drive the loss to ~0.)
    log_denominator = torch.logsumexp(l_neg, dim=1)

    # loss = -log(exp(l_pos) / sum_k exp(l_neg_k)), averaged over the batch
    loss = (log_denominator - l_pos.squeeze(-1)).mean()

    return loss
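To make the effect of the denominator concrete, here is a minimal pure-Python sketch that evaluates both forms for a single query. The function names and the similarity scores are hypothetical and chosen only for illustration; nothing here comes from the CoST code base.

```python
import math


def log_sum_exp(values):
    """Numerically stable log(sum(exp(v) for v in values))."""
    m = max(values)
    return m + math.log(sum(math.exp(v - m) for v in values))


def info_nce(pos, negs, temperature):
    """Standard InfoNCE: the positive logit is part of the denominator."""
    logits = [pos / temperature] + [n / temperature for n in negs]
    return log_sum_exp(logits) - pos / temperature


def negatives_only(pos, negs, temperature):
    """Variant: the denominator sums over the negatives only."""
    logits = [n / temperature for n in negs]
    return log_sum_exp(logits) - pos / temperature


# Hypothetical similarity scores: one positive and three negatives.
pos, negs, temperature = 0.9, [0.1, -0.2, 0.3], 0.07

standard = info_nce(pos, negs, temperature)
variant = negatives_only(pos, negs, temperature)
print(f"standard InfoNCE: {standard:.4f}")
print(f"negatives-only:   {variant:.4f}")
```

The two values are related by standard = log(1 + exp(variant)), so the standard loss stays above 0 while the negatives-only variant becomes negative once the positive score dominates the negatives. Keep this in mind when comparing loss curves from the two versions.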

Hope this helps.

740402059 commented 7 months ago

Thank you. I made the modification according to your suggestion, and this problem has been solved.

Wentao-Gao commented 7 months ago

By the way, I just noticed that MoCo v2 uses the same loss function as this paper.