PyTorch code for "Prototypical Contrastive Learning of Unsupervised Representations"
kroegern1 commented 8 months ago

Hello, I'm trying to apply this PCL technique to CIFAR-10 since I don't have much GPU access. After modifying the code to run on the CPU and on MPS in PyTorch, I ran into a few problems. The one I'm most stuck on is in the training loop.

  1. The pseudo-code showing the EM updates doesn't reflect the code. In the code I see:

for epoch in start_epoch to epochs:
    if epoch >= warmup_epoch:
        features = compute_features(data_loader, model)
        halve features whose L2 norm exceeds 1.5
        cluster_result = run_kmeans(features)

    for batch in train_loader:
        load data
        compute model outputs and targets for both instance and prototype learning
        calculate InfoNCE loss and, if applicable, ProtoNCE loss
        perform backpropagation

I don't see how this code performs an E-step and then an M-step. Why is there a mismatch, and which one is correct?
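To make my reading of the loop concrete, here is a small sketch that only records the order of operations. The E-step/M-step labels are my own assumption about how the loop maps onto EM (clustering once per epoch, then a full pass of minibatch updates against those fixed clusters); `training_schedule` is a hypothetical helper, not something from the repo.

```python
def training_schedule(epochs, warmup_epoch, num_batches):
    """Return the ordered (epoch, phase) events of the loop structure above."""
    events = []
    for epoch in range(epochs):
        clustered = epoch >= warmup_epoch
        if clustered:
            # Assumed E-step: embed the whole dataset once, then run k-means.
            events.append((epoch, "E-step: compute_features + run_kmeans"))
        for batch in range(num_batches):
            if clustered:
                # Assumed M-step: SGD updates against the clusters fixed above.
                events.append((epoch, f"M-step batch {batch}: InfoNCE + ProtoNCE"))
            else:
                # Warm-up: no clusters exist yet, so only the instance loss applies.
                events.append((epoch, f"SGD batch {batch}: InfoNCE only"))
    return events


events = training_schedule(epochs=3, warmup_epoch=1, num_batches=2)
for event in events:
    print(event)
```

If this reading is right, the two steps alternate at epoch granularity rather than per batch, which may be why it doesn't look like a separate E-step followed by an M-step.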

  2. Assuming the code is correct, why do we calculate cluster_result on the eval_dataset (10,000 samples for CIFAR-10)? "eval_loader" and "len(eval_dataset)" seem to be the culprits:
if epoch >= warmup_epoch:
    # compute momentum features for center-cropped images
    features = compute_features(eval_loader, model, low_dim, device)

    # placeholder for clustering result
    cluster_result = {'im2cluster': [], 'centroids': [], 'density': []}
    for num_cluster in num_clusters:  # assuming num_clusters is an iterable of desired cluster counts
        cluster_result['im2cluster'].append(torch.zeros(len(eval_dataset), dtype=torch.long))

However, when we call train a few lines later, we pass in train_loader (which has 50,000 samples, not the 10,000 in eval_loader) and cluster_result (which holds assignments for only 10,000 samples).

train(train_loader, model, criterion, optimizer, epoch, device, cluster_result)
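A check along these lines would surface the size mismatch before training starts. This is only a sketch: `check_cluster_result` is a hypothetical helper I made up, and plain lists stand in for the tensors (it relies only on `len()`, which also works on 1-D tensors).

```python
def check_cluster_result(cluster_result, num_train_samples):
    """Raise early if an im2cluster table cannot cover every training index."""
    for level, im2cluster in enumerate(cluster_result["im2cluster"]):
        if len(im2cluster) != num_train_samples:
            raise ValueError(
                f"im2cluster[{level}] has {len(im2cluster)} entries, "
                f"but the training set has {num_train_samples} samples"
            )


# Stand-in for my setup: one clustering level, sized for the 10,000-image split.
cluster_result = {"im2cluster": [[0] * 10_000], "centroids": [], "density": []}

try:
    check_cluster_result(cluster_result, num_train_samples=50_000)
    message = ""
except ValueError as exc:
    message = str(exc)

print(message)
```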

Therefore, there's a shape mismatch in the forward function when we do

    111 for n, (im2cluster, prototypes, density) in enumerate(zip(cluster_result['im2cluster'], cluster_result['centroids'], cluster_result['density'])):
    112     # get positive prototypes
--> 113     pos_proto_id = im2cluster[index]

Thus leading to this error:

IndexError                                Traceback (most recent call last)
Cell In[77], line 28
     24 adjust_learning_rate(optimizer, epoch, lr)
     26 # train for one epoch
     27 # print("device", device)
---> 28 losses = train(train_loader, model, criterion, optimizer, epoch, device, cluster_result)
     30 print('Epoch: [{0}]\t'.format(epoch, loss=losses))
     32 if (epoch+1)%5==0:

Cell In[76], line 44, in train(train_loader, model, criterion, optimizer, epoch, device, cluster_result)
     34 # visualize the images
     35 # plt.figure(figsize=(6, 3))  # Adjust the size as necessary
     36 # plt.subplot(1, 2, 1)
   (...)
     41 # compute output
     42 ### HHEEEREEE
     43 print("min",min(index), "max", max(index))
---> 44 output, target, output_proto, target_proto = model(im_q=images[0], im_k=images[1], cluster_result=cluster_result, index=index)
     45
     46 # InfoNCE loss

File ~/anaconda3/envs/x/lib/python3.10/site-packages/torch/nn/modules/, in Module._wrapped_call_impl(self, *args, **kwargs)
   1516     return self._compiled_call_impl(*args, **kwargs)  # type: ignore[misc]
   1517 else:
-> 1518     return self._call_impl(*args, **kwargs)

File ~/anaconda3/envs/x/lib/python3.10/site-packages/torch/nn/modules/, in Module._call_impl(self, *args, **kwargs)
   1522 # If we don't have any hooks, we want to skip the rest of the logic in
   1523 # this function, and just call forward.
   1524 if not (self._backward_hooks or self._backward_pre_hooks or self._forward_hooks or self._forward_pre_hooks
   1525         or _global_backward_pre_hooks or _global_backward_hooks
   1526         or _global_forward_hooks or _global_forward_pre_hooks):
-> 1527     return forward_call(*args, **kwargs)
   1529 try:
   1530     result = None

Cell In[62], line 113, in MoCo.forward(self, im_q, im_k, is_eval, cluster_result, index)
    110 proto_logits = []
    111 for n, (im2cluster, prototypes, density) in enumerate(zip(cluster_result['im2cluster'], cluster_result['centroids'], cluster_result['density'])):
    112     # get positive prototypes
--> 113     pos_proto_id = im2cluster[index]
    114     pos_prototypes = prototypes[pos_proto_id]
    116     # sample negative prototypes

IndexError: index 18843 is out of bounds for dimension 0 with size 10000
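Here is a stripped-down reproduction of what I think is going on, assuming `index` holds each sample's position in the training dataset. Plain Python lists stand in for the tensors, and `positive_prototype_ids` and the batch values are hypothetical, chosen only to mirror `im2cluster[index]`.

```python
NUM_TRAIN, NUM_EVAL = 50_000, 10_000  # CIFAR-10 train / test split sizes


def positive_prototype_ids(im2cluster, index):
    """Mimic im2cluster[index]: one cluster id per sample index in the batch."""
    return [im2cluster[i] for i in index]


# Hypothetical batch of dataset indices as they would come from train_loader.
batch_index = [12, 18843, 49_999]

# Assignment table sized for the 10,000-image split: index 18843 cannot resolve.
short_table = [0] * NUM_EVAL
try:
    positive_prototype_ids(short_table, batch_index)
    failed = False
except IndexError:
    failed = True

# Assignment table sized for the 50,000-image training split: every index resolves.
full_table = [0] * NUM_TRAIN
ids = positive_prototype_ids(full_table, batch_index)

print(failed, len(ids))
```

If that assumption holds, the lookup can only work when clustering is run over the same 50,000 images, in the same order, that train_loader draws from.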

Any ideas on how index works, or what cluster_result should hold?