xuguodong03 / SSKD

[ECCV2020] Knowledge Distillation Meets Self-Supervision

Contrastive Prediction #6

Open larry10hhobh opened 4 years ago

larry10hhobh commented 4 years ago

Hi

Thank you for your code. I have a question about the contrastive prediction code in student.py:

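```python
import time

import torch
import torch.nn.functional as F

# train ssp_head
# (args, t_model, t_optimizer, train_loader and AverageMeter are defined elsewhere in the repo)
for epoch in range(args.t_epoch):

    t_model.eval()
    loss_record = AverageMeter()
    acc_record = AverageMeter()

    start = time.time()
    for x, _ in train_loader:

        t_optimizer.zero_grad()

        x = x.cuda()
        c, h, w = x.size()[-3:]
        x = x.view(-1, c, h, w)  # merge all leading dims into one batch dim

        _, rep, feat = t_model(x, bb_grad=False)
        batch = int(x.size(0) / 4)
        nor_index = (torch.arange(4*batch) % 4 == 0).cuda()  # rows 0, 4, 8, ...
        aug_index = (torch.arange(4*batch) % 4 != 0).cuda()  # all other rows

        nor_rep = rep[nor_index]  # (batch, F)
        aug_rep = rep[aug_index]  # (3*batch, F)
        nor_rep = nor_rep.unsqueeze(2).expand(-1,-1,3*batch).transpose(0,2)  # (3*batch, F, batch)
        aug_rep = aug_rep.unsqueeze(2).expand(-1,-1,1*batch)  # (3*batch, F, batch)
        simi = F.cosine_similarity(aug_rep, nor_rep, dim=1)  # (3*batch, batch)
        # target = [0, 0, 0, 1, 1, 1, ...]: each nor_rep index repeated 3 times
        target = torch.arange(batch).unsqueeze(1).expand(-1,3).contiguous().view(-1).long().cuda()
        loss = F.cross_entropy(simi, target)
```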

It seems to me that nor_rep and aug_rep come from different samples, so this is not the relation between X and its transformations described in the paper. Is my understanding wrong?

xuguodong03 commented 4 years ago

Hi, thanks for running this repo.

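The batch from train_loader has shape 64x4x3x32x32. The dimension of size 4 holds one normal (untransformed) image plus three transformed versions of that same image. After x.view(), the shape is (64x4)x3x32x32, and the four views of each image stay contiguous, so every row whose index satisfies i % 4 == 0 is a normal image and the next three rows are its transformations. Suppose the output feature shape is (64x4)xF. nor_index and aug_index split the output features into two tensors: 64xF (normal) and 192xF (transformed). These two tensors are nor_rep and aug_rep, respectively.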
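To check that layout, here is a small CPU sketch with toy tensors. It assumes hypothetical sizes (2 images, 4 views each, 5-dim features) and uses random features in place of the teacher's `rep`; the indexing, similarity, and target lines mirror the student.py snippet in the question, minus `.cuda()`. The extra `sample_id` tensor is a hypothetical helper that tags each row with its source image.

```python
import torch
import torch.nn.functional as F

torch.manual_seed(0)

# Hypothetical toy sizes: 2 images, 4 views per image, 5-dim features.
batch, n_views, feat_dim = 2, 4, 5

# Tag every row of the flattened batch with its source image and view.
sample_id = torch.arange(batch).repeat_interleave(n_views)  # [0, 0, 0, 0, 1, 1, 1, 1]
view_id = torch.arange(n_views).repeat(batch)               # [0, 1, 2, 3, 0, 1, 2, 3]

# Random stand-in for `rep`, laid out as (batch*4) x F, as after x.view().
rep = torch.randn(batch * n_views, feat_dim)

nor_index = torch.arange(n_views * batch) % n_views == 0
aug_index = torch.arange(n_views * batch) % n_views != 0

nor_rep = rep[nor_index]  # (batch, F): view 0 of every image
aug_rep = rep[aug_index]  # (3*batch, F): views 1-3 of every image

print(sample_id[nor_index].tolist())  # [0, 1]
print(sample_id[aug_index].tolist())  # [0, 0, 0, 1, 1, 1]

# Same broadcasting trick as in student.py.
nor_e = nor_rep.unsqueeze(2).expand(-1, -1, 3 * batch).transpose(0, 2)  # (3*batch, F, batch)
aug_e = aug_rep.unsqueeze(2).expand(-1, -1, 1 * batch)                  # (3*batch, F, batch)
simi = F.cosine_similarity(aug_e, nor_e, dim=1)                         # (3*batch, batch)

# simi[i, j] = cos(aug_rep[i], nor_rep[j]); an equivalent normalized matmul:
simi_mm = F.normalize(aug_rep, dim=1) @ F.normalize(nor_rep, dim=1).t()
assert torch.allclose(simi, simi_mm, atol=1e-5)

# The target of transformed view i is the index of its own source image.
target = torch.arange(batch).unsqueeze(1).expand(-1, 3).contiguous().view(-1).long()
assert torch.equal(target, sample_id[aug_index])

loss = F.cross_entropy(simi, target)
print(simi.shape, target.tolist())  # torch.Size([6, 2]) [0, 0, 0, 1, 1, 1]
```

Row `i` of `simi` holds the similarity between transformed view `i` and every normal view in the batch, so `cross_entropy` treats each row as logits over `batch` candidates and asks each transformed view to pick out the normal view of its own source image.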

larry10hhobh commented 4 years ago

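Hi, my confusion is mainly about where this 4 comes from. PyTorch applies augmentation online, so within one epoch a sample and its transformations shouldn't appear at the same time, right? And even if they did, how can we be sure each group is one original sample plus three augmented ones? Where does this 1+3 come from?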

xuguodong03 commented 4 years ago

I didn't use torchvision.datasets.CIFAR100; instead, I modified the dataset. See cifar.py.
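Presumably the modified dataset builds all four views inside a single `__getitem__` call from the same image, which is why they are always batched together even though augmentation happens on the fly. A minimal sketch of that idea follows; the class name `FourViewDataset`, the random stand-in data, and the use of 90/180/270-degree rotations as the three transformations are all assumptions for illustration, not the actual contents of cifar.py.

```python
import torch
from torch.utils.data import DataLoader, Dataset


class FourViewDataset(Dataset):
    """Hypothetical wrapper returning 1 normal view + 3 transformed views per item.

    The rotations below are an assumed choice of transformation for this sketch.
    """

    def __init__(self, images: torch.Tensor, labels: torch.Tensor):
        self.images = images  # (N, 3, 32, 32) float tensor
        self.labels = labels  # (N,) int64 tensor

    def __len__(self) -> int:
        return self.images.size(0)

    def __getitem__(self, index: int):
        img = self.images[index]
        # All four views come from the same image in the same call, so they
        # stay together regardless of shuffling or on-the-fly augmentation.
        views = [img] + [torch.rot90(img, k, dims=(1, 2)) for k in (1, 2, 3)]
        return torch.stack(views, dim=0), self.labels[index]  # (4, 3, 32, 32)


# Random stand-in for CIFAR-100 so the sketch runs without downloading data.
images = torch.randn(128, 3, 32, 32)
labels = torch.randint(0, 100, (128,))
loader = DataLoader(FourViewDataset(images, labels), batch_size=64, shuffle=True)

x, y = next(iter(loader))
print(x.shape)  # torch.Size([64, 4, 3, 32, 32])

# Flattening as in student.py keeps each image's four views in consecutive rows.
flat = x.view(-1, *x.shape[-3:])
print(flat.shape)  # torch.Size([256, 3, 32, 32])
```

With this layout, rows 0-3 of `flat` are one image and its three transforms, rows 4-7 the next image, and so on, which is exactly what `nor_index` and `aug_index` in student.py rely on.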

larry10hhobh commented 4 years ago

OK, thanks for the explanation!