salesforce / CoST

PyTorch code for CoST: Contrastive Learning of Disentangled Seasonal-Trend Representations for Time Series Forecasting (ICLR 2022)
BSD 3-Clause "New" or "Revised" License

compute loss (labels: torch.zeros) #11

Closed shlee-home closed 1 year ago

shlee-home commented 2 years ago

Hello again. I'm studying your paper and code. However, in the following code from your 'cost.py' file,

# positive logits: Nx1
l_pos = torch.einsum('nc,nc->n', [a1, a2]).unsqueeze(-1)
# negative logits: NxK
l_neg = torch.einsum('nc,ck->nk', [a1, a2_neg])

# logits: Nx(1+K)
logits = torch.cat([l_pos, l_neg], dim=1)

# apply temperature
logits /= T

# labels: positive key indicators - first dim of each batch
labels = torch.zeros(logits.shape[0], dtype=torch.long).cuda()
loss = F.cross_entropy(logits, labels)

I would have expected the positive pair's label to be 1 rather than 0, since the positive logit is the target when the cross-entropy is calculated. I'm not sure I understand this correctly, so I'd appreciate your advice. Thank you.

gorold commented 2 years ago

Hi, thanks for your interest in our work. The labels are set to torch.zeros(logits.shape[0]) because F.cross_entropy takes class indices, not one-hot vectors, as its second argument. Since l_pos is concatenated first, the positive logit sits at index 0 of each row, so the target class index is 0 for every sample, hence torch.zeros. You can refer to the function documentation here.
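
To make this concrete, here is a minimal, self-contained sketch (the shapes and values are made up for illustration; only the einsum/cat/cross_entropy pattern comes from cost.py). It shows that all-zero class indices point F.cross_entropy at column 0 of the logits, which is exactly where l_pos was concatenated:

import torch
import torch.nn.functional as F

N, C, K, T = 4, 8, 16, 0.07                      # batch, feature dim, num negatives, temperature
a1 = F.normalize(torch.randn(N, C), dim=1)       # query embeddings
a2 = F.normalize(torch.randn(N, C), dim=1)       # positive key embeddings
a2_neg = F.normalize(torch.randn(C, K), dim=0)   # negative keys, one per column

l_pos = torch.einsum('nc,nc->n', [a1, a2]).unsqueeze(-1)   # Nx1 positive logits
l_neg = torch.einsum('nc,ck->nk', [a1, a2_neg])            # NxK negative logits
logits = torch.cat([l_pos, l_neg], dim=1) / T              # Nx(1+K); positive is column 0

labels = torch.zeros(N, dtype=torch.long)        # class index 0 = the positive column
loss = F.cross_entropy(logits, labels)

# Equivalent by hand: negative log-softmax probability of column 0, averaged over the batch
manual = -F.log_softmax(logits, dim=1)[:, 0].mean()
assert torch.allclose(loss, manual)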