salesforce / CoST

PyTorch code for CoST: Contrastive Learning of Disentangled Seasonal-Trend Representations for Time Series Forecasting (ICLR 2022)
BSD 3-Clause "New" or "Revised" License

The use of instance_contrastive_loss #5

Closed: chenxiaodanhit closed this issue 2 years ago.

chenxiaodanhit commented 2 years ago

Thanks for your great work! Could you please explain the instance contrastive loss as written in cost.py?

    logits = torch.tril(sim, diagonal=-1)[:, :, :-1]  # T x 2B x (2B-1)
    logits += torch.triu(sim, diagonal=1)[:, :, 1:]
    logits = -F.log_softmax(logits, dim=-1)
    i = torch.arange(B, device=z1.device)
    loss = (logits[:, i, B + i - 1].mean() + logits[:, B + i, i].mean()) / 2
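For readers who want to run the snippet, here is a minimal self-contained sketch. The function name comes from the issue title; the signature, the (B, T, C) input shapes, the B == 1 guard, and the three lines that build `sim` are assumptions, since they are not part of the quoted code.

```python
import torch
import torch.nn.functional as F


def instance_contrastive_loss(z1: torch.Tensor, z2: torch.Tensor) -> torch.Tensor:
    """InfoNCE over the batch, evaluated independently at each index t.

    Assumption: z1 and z2 are two views of shape (B, T, C). Row b of z1 and
    row b of z2 form the positive pair; the other 2B - 2 rows are negatives.
    """
    B = z1.size(0)
    if B == 1:
        # Assumed guard: with a single instance there are no negatives.
        return z1.new_tensor(0.0)
    z = torch.cat([z1, z2], dim=0)            # 2B x T x C
    z = z.transpose(0, 1)                     # T x 2B x C
    sim = torch.matmul(z, z.transpose(1, 2))  # T x 2B x 2B dot-product similarities
    # Remove the diagonal: keep the strictly-lower part, then add the
    # strictly-upper part shifted one column to the left.
    logits = torch.tril(sim, diagonal=-1)[:, :, :-1]  # T x 2B x (2B-1)
    logits += torch.triu(sim, diagonal=1)[:, :, 1:]
    logits = -F.log_softmax(logits, dim=-1)
    i = torch.arange(B, device=z1.device)
    # Anchor i (view 1): positive was column B+i, shifted to B+i-1 by the diagonal removal.
    # Anchor B+i (view 2): positive is column i, unshifted because i < B+i.
    loss = (logits[:, i, B + i - 1].mean() + logits[:, B + i, i].mean()) / 2
    return loss


torch.manual_seed(0)
z1 = torch.randn(4, 10, 16)
z2 = z1 + 0.1 * torch.randn(4, 10, 16)  # a slightly perturbed second view
print(instance_contrastive_loss(z1, z2))  # a scalar tensor
```

Views that agree should give a lower loss than unrelated views, because the positive similarity then dominates the softmax denominator.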

The above definition seems different from L_amp and L_phase as defined in the paper.
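To compare the snippet with the paper's losses, it helps to write out what the code computes. This is a derivation from the quoted lines only, assuming `sim` is the batched dot product of the 2B stacked vectors (z1 rows first, then z2 rows) at each index t. Let s_ab^(t) be the similarity between rows a and b (0-indexed) and p(a) the row holding the other view of the same instance:

$$
p(a) = \begin{cases} a + B, & a < B \\ a - B, & a \ge B \end{cases}
\qquad
\ell_a^{(t)} = -\log \frac{\exp\big(s^{(t)}_{a,\,p(a)}\big)}{\sum_{b \neq a} \exp\big(s^{(t)}_{ab}\big)}
$$

$$
\mathcal{L} = \frac{1}{2}\Big(\frac{1}{TB}\sum_{t=0}^{T-1}\sum_{a=0}^{B-1}\ell_a^{(t)} + \frac{1}{TB}\sum_{t=0}^{T-1}\sum_{a=B}^{2B-1}\ell_a^{(t)}\Big) = \frac{1}{2TB}\sum_{t=0}^{T-1}\sum_{a=0}^{2B-1}\ell_a^{(t)}
$$

The denominator runs over all 2B - 1 candidates other than the anchor itself; the two `.mean()` calls each average over T x B terms.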

Looking forward to your reply! Thanks in advance!

gorold commented 2 years ago

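Hey @chenxiaodanhit, thanks for your interest in our work. The above implementation is equivalent to L_amp and L_phase as defined in the paper. To explain: the torch.tril and torch.triu lines remove the diagonal of the similarity matrix (the i == j entries, i.e. each sample's similarity with itself), and logits then holds the negative log-softmax term for every (anchor, candidate) pair. We only need the terms where the candidate is the anchor's positive pair, so the last line selects those indices. Because removing the diagonal shifts every column to its right left by one, the positive for anchor i in the first half sits at column B + i - 1, while the positive for anchor B + i sits at column i. Hope this helps.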
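Both claims in this reply can be checked numerically. The sketch below, assuming random inputs of shape (B, T, C) and dot-product similarity, compares the tril/triu trick against an explicit boolean mask that drops the diagonal, and compares the vectorized loss against a plain loop that evaluates -log(exp(s_pos) / sum over b != a of exp(s_ab)) for every anchor. The helper names (`shifted_logits`, `masked_logits`, `vectorized_loss`, `loop_loss`) are hypothetical, chosen for this illustration.

```python
import torch
import torch.nn.functional as F


def shifted_logits(sim: torch.Tensor) -> torch.Tensor:
    # The tril/triu trick from the snippet: drops sim[:, r, r] from every row r.
    logits = torch.tril(sim, diagonal=-1)[:, :, :-1]
    logits += torch.triu(sim, diagonal=1)[:, :, 1:]
    return logits


def masked_logits(sim: torch.Tensor) -> torch.Tensor:
    # Reference: remove the diagonal with an explicit boolean mask.
    T, N, _ = sim.shape
    keep = ~torch.eye(N, dtype=torch.bool)
    return sim[:, keep].reshape(T, N, N - 1)


def vectorized_loss(z1: torch.Tensor, z2: torch.Tensor) -> torch.Tensor:
    B = z1.size(0)
    z = torch.cat([z1, z2], dim=0).transpose(0, 1)  # T x 2B x C
    sim = torch.matmul(z, z.transpose(1, 2))        # T x 2B x 2B
    logits = -F.log_softmax(shifted_logits(sim), dim=-1)
    i = torch.arange(B)
    return (logits[:, i, B + i - 1].mean() + logits[:, B + i, i].mean()) / 2


def loop_loss(z1: torch.Tensor, z2: torch.Tensor) -> torch.Tensor:
    B, T = z1.size(0), z1.size(1)
    z = torch.cat([z1, z2], dim=0)  # 2B x T x C
    total = 0.0
    for t in range(T):
        s = z[:, t] @ z[:, t].T  # 2B x 2B similarities at index t
        for a in range(2 * B):
            p = a + B if a < B else a - B  # the other view of the same instance
            others = torch.cat([s[a, :a], s[a, a + 1:]])  # every b != a
            total = total - (s[a, p] - torch.logsumexp(others, dim=0))
    return total / (2 * B * T)


torch.manual_seed(0)
z1, z2 = torch.randn(3, 5, 4), torch.randn(3, 5, 4)
z = torch.cat([z1, z2]).transpose(0, 1)
sim = z @ z.transpose(1, 2)
diag_ok = torch.equal(shifted_logits(sim), masked_logits(sim))
loss_ok = torch.allclose(vectorized_loss(z1, z2), loop_loss(z1, z2), atol=1e-5)
print(diag_ok, loss_ok)
```

If the explanation above holds, both checks print True: the trick only removes the diagonal, and the index selection picks exactly the positive-pair terms.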

chenxiaodanhit commented 2 years ago

Thanks for your reply!