HobbitLong / RepDistiller

[ICLR 2020] Contrastive Representation Distillation (CRD), and benchmark of recent knowledge distillation methods
BSD 2-Clause "Simplified" License

Questions about ContrastMemory #11

Closed HelloTobe closed 4 years ago

HelloTobe commented 4 years ago

@HobbitLong Thanks for your excellent work! I'm wondering why you register two memories here, `self.register_buffer('memory_v1', torch.rand(outputSize, inputSize).mul_(2 * stdv).add_(-stdv))` and `self.register_buffer('memory_v2', torch.rand(outputSize, inputSize).mul_(2 * stdv).add_(-stdv))`, since only the student network is trained while the teacher is fixed.

HelloTobe commented 4 years ago

@HobbitLong Besides, I notice that `outputSize` equals the length of the dataset for CIFAR-100 (50K here). Is it the same for ImageNet, which has more than 1M training images?

HobbitLong commented 4 years ago

Hi, @HelloTobe

> I'm wondering why you register two memories here, `self.register_buffer('memory_v1', torch.rand(outputSize, inputSize).mul_(2 * stdv).add_(-stdv))` and `self.register_buffer('memory_v2', torch.rand(outputSize, inputSize).mul_(2 * stdv).add_(-stdv))`

A linear projection layer still needs to be learned during training. This layer aligns the feature dimensions of the student and teacher.
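A minimal sketch of such a learnable projection (names hypothetical, loosely following the repo's `Embed` module): each network's features are mapped to a common dimension and L2-normalized before contrasting.

```python
import torch
import torch.nn as nn
import torch.nn.functional as F

class Embed(nn.Module):
    """Project features to a shared dimension and L2-normalize them."""
    def __init__(self, dim_in, dim_out):
        super().__init__()
        self.linear = nn.Linear(dim_in, dim_out)

    def forward(self, x):
        x = x.view(x.size(0), -1)      # flatten per-sample features
        x = self.linear(x)             # learnable dimension alignment
        return F.normalize(x, dim=1)   # unit-length embeddings

# e.g. 64-d student and 256-d teacher features, both mapped to 128-d
embed_s, embed_t = Embed(64, 128), Embed(256, 128)
f_s = embed_s(torch.randn(4, 64))
f_t = embed_t(torch.randn(4, 256))
```

Only these small projection heads (and the student) receive gradients; the teacher backbone itself stays frozen.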

> Besides, I notice that `outputSize` equals the length of the dataset for CIFAR-100 (50K here). Is it the same for ImageNet, which has more than 1M training images?

That's related to NCE. Setting `outputSize = len(dataset)` suffices, so it's 1M for ImageNet.
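Concretely, each memory bank is just an `outputSize x inputSize` buffer with one row per training sample, initialized uniformly; a sketch under that assumption (the `stdv` scale mirrors the initialization quoted above):

```python
import math
import torch

n_data, feat_dim = 50000, 128              # CIFAR-100: outputSize = len(dataset)
stdv = 1.0 / math.sqrt(feat_dim / 3.0)     # scale for uniform init in [-stdv, stdv)

# one bank per view; row i caches the embedding of training sample i
memory_v1 = torch.rand(n_data, feat_dim).mul_(2 * stdv).add_(-stdv)
memory_v2 = torch.rand(n_data, feat_dim).mul_(2 * stdv).add_(-stdv)
```

For ImageNet the same construction applies with `n_data ≈ 1.28e6`, so the bank grows linearly with the dataset, not with training time.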

HelloTobe commented 4 years ago

@HobbitLong Many thanks!

HelloTobe commented 4 years ago

Hi, @HobbitLong

`weight_v1 = torch.index_select(self.memory_v1, 0, idx.view(-1)).detach()`
`weight_v1 = weight_v1.view(batchSize, K + 1, inputSize)`
`out_v2 = torch.bmm(weight_v1, v2.view(batchSize, inputSize, 1))`

According to the above code, why is `memory_v1` updated with `v1` instead of `v2` in the following? There seems to be no relation between `memory_v1` and `v1`.

`l_pos = torch.index_select(self.memory_v1, 0, y.view(-1))`
`l_pos.mul_(momentum)`
`l_pos.add_(torch.mul(v1, 1 - momentum))`

HobbitLong commented 4 years ago

@HelloTobe, the point of CRD is contrasting teacher against student, e.g., the anchor is drawn from the teacher (`v2`), while the positives and negatives are drawn from the student memory (`memory_v1`). Since `memory_v1` caches student features, it is naturally refreshed with the fresh student features `v1`.
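A hedged end-to-end sketch of that pairing, simplified from the `ContrastMemory` snippets quoted above (sizes and the unique-index sampling are illustrative assumptions): the teacher embedding `v2` is scored against `K` negatives plus one positive drawn from the student bank, and the bank is then refreshed with `v1` via a momentum update.

```python
import torch
import torch.nn.functional as F

batch, K, dim, n_data, momentum = 4, 16, 8, 100, 0.5
memory_v1 = torch.randn(n_data, dim)            # student memory bank
v1 = torch.randn(batch, dim)                    # fresh student features
v2 = torch.randn(batch, dim)                    # fresh teacher features (anchors)
y = torch.randperm(n_data)[:batch]              # unique sample indices
idx = torch.randint(0, n_data, (batch, K + 1))  # K negatives per anchor...
idx[:, 0] = y                                   # ...plus the positive in slot 0

# score each teacher anchor v2 against its (K+1) student entries
weight_v1 = torch.index_select(memory_v1, 0, idx.view(-1)).detach()
weight_v1 = weight_v1.view(batch, K + 1, dim)
out_v2 = torch.bmm(weight_v1, v2.view(batch, dim, 1))  # (batch, K+1, 1)

# memory_v1 holds *student* features, so it is refreshed with v1
with torch.no_grad():
    l_pos = torch.index_select(memory_v1, 0, y.view(-1))
    l_pos.mul_(momentum).add_(torch.mul(v1, 1 - momentum))
    memory_v1.index_copy_(0, y, F.normalize(l_pos, dim=1))
```

The symmetric branch does the mirror image: the student embedding `v1` is contrasted against `memory_v2`, which is refreshed with `v2`.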