SimpleShot "The size of tensor a (640) must match the size of tensor b (10) at non-singleton dimension 1" #151

Status: Open. kimchiTuna opened this issue 4 months ago.

kimchiTuna commented 4 months ago

Hi, thanks for the great repo. I am new to few-shot learning (FSL) and have run into a problem I cannot solve.

Problem

I am using the classical training notebook and trying to use SimpleShot instead of Prototypical Networks. I'm not sure I fully understand the code, but I implemented the following, and the embeddings were predicted successfully:

feature_centering = compute_average_features_from_images(train_loader, model, DEVICE)
few_shot_classifier = SimpleShot(
    model,
    feature_centering=feature_centering,
    feature_normalization=2.0,
    use_softmax=True,
).to(DEVICE)

Result: Predicting embeddings: 100%|██████████████████████████████████████████████████████████| 4/4 [00:03<00:00, 1.30batch/s]

However, I get the following error during validation: "RuntimeError: The size of tensor a (640) must match the size of tensor b (10) at non-singleton dimension 1".
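PyTorch broadcasting aligns trailing dimensions, so subtracting a 10-element centering vector from features of shape (n_images, 640) fails at the line `centered_features = original_features - self.feature_centering` in the traceback below. Here is a minimal plain-PyTorch reproduction. It assumes that the 10 is the number of training classes, i.e. the output size of the FC layer, which is what the follow-up comment at the end of this issue points to. `centering_error` is a hypothetical helper written only for this sketch:

```python
import torch


def centering_error(features: torch.Tensor, centering: torch.Tensor):
    """Return the broadcasting error message for features - centering, or None if it works."""
    try:
        features - centering
    except RuntimeError as error:
        return str(error)
    return None


# 640-dimensional backbone features for 5 images (FC layer disabled), as during validation.
features = torch.randn(5, 640)
# Centering averaged while the FC layer was enabled: one value per class (assumed 10 here).
bad_centering = torch.randn(10)
# Centering averaged over 640-dimensional backbone features.
good_centering = torch.randn(640)

print(centering_error(features, bad_centering))   # the RuntimeError message
print(centering_error(features, good_centering))  # None: the shapes broadcast fine
```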


Validation:   0%|                                                                              | 0/500 [00:02<?, ?it/s]
RuntimeError                              Traceback (most recent call last)
Cell In[8], line 16
     11 if epoch % validation_frequency == validation_frequency - 1:
     13     # We use this very convenient method from EasyFSL's ResNet to specify
     14     # that the model shouldn't use its last fully connected layer during validation.
     15     model.set_use_fc(False)
---> 16     validation_accuracy = evaluate(
     17         few_shot_classifier, val_loader, device=DEVICE, tqdm_prefix="Validation"
     18     )
     19     model.set_use_fc(True)
     21     if validation_accuracy > best_validation_accuracy:

File ~\Documents\GitHub\easy-few-shot-learning-master\easyfsl\, in evaluate(model, data_loader, device, use_tqdm, tqdm_prefix)
    144 with tqdm(
    145     enumerate(data_loader),
    146     total=len(data_loader),
    147     disable=not use_tqdm,
    148     desc=tqdm_prefix,
    149 ) as tqdm_eval:
    150     for _, (
    151         support_images,
    152         support_labels,
    155         _,
    156     ) in tqdm_eval:
--> 157         correct, total = evaluate_on_one_task(
    158             model,
    159   ,
    160   ,
    161   ,
    162   ,
    163         )
    165         outputs = model(
    166         LOSS_FUNCTION = nn.CrossEntropyLoss()

File ~\Documents\GitHub\easy-few-shot-learning-master\easyfsl\, in evaluate_on_one_task(model, support_images, support_labels, query_images, query_labels)
     93 def evaluate_on_one_task(
     94     model: FewShotClassifier,
     95     support_images: Tensor,
     98     query_labels: Tensor,
     99 ) -> Tuple[int, int]:
    100     """
    101     Returns the number of correct predictions of query labels, and the total number of
    102     predictions.
    103     """
--> 104     model.process_support_set(support_images, support_labels)
    105     predictions = model(query_images).detach().data
    106     number_of_correct_predictions = int(
    107         (torch.max(predictions, 1)[1] == query_labels).sum().item()
    108     )

File ~\Documents\GitHub\easy-few-shot-learning-master\easyfsl\methods\, in FewShotClassifier.process_support_set(self, support_images, support_labels)
     65 def process_support_set(
     66     self,
     67     support_images: Tensor,
     68     support_labels: Tensor,
     69 ):
     70     """
     71     Harness information from the support set, so that query labels can later be predicted using a forward call.
     72     The default behaviour shared by most few-shot classifiers is to compute prototypes and store the support set.
     75         support_labels: labels of support set images of shape (n_support, )
     76     """
---> 77     self.compute_prototypes_and_store_support_set(support_images, support_labels)

File ~\Documents\GitHub\easy-few-shot-learning-master\easyfsl\methods\, in FewShotClassifier.compute_prototypes_and_store_support_set(self, support_images, support_labels)
    141 """
    142 Extract support features, compute prototypes, and store support labels, features, and prototypes.
    143 Args:
    144     support_images: images of the support set of shape (n_support, **image_shape)
    145     support_labels: labels of support set images of shape (n_support, )
    146 """
    147 self.support_labels = support_labels
--> 148 self.support_features = self.compute_features(support_images)
    149 self._raise_error_if_features_are_multi_dimensional(self.support_features)
    150 self.prototypes = compute_prototypes(self.support_features, support_labels)

File ~\Documents\GitHub\easy-few-shot-learning-master\easyfsl\methods\, in FewShotClassifier.compute_features(self, images)
     86 """
     87 Compute features from images and perform centering and normalization.
     88 Args:
     91     features of shape (n_images, feature_dimension)
     92 """
     93 original_features = self.backbone(images)
---> 94 centered_features = original_features - self.feature_centering
     95 if self.feature_normalization is not None:
     96     return nn.functional.normalize(
     97         centered_features, p=self.feature_normalization, dim=1
     98     )

RuntimeError: The size of tensor a (640) must match the size of tensor b (10) at non-singleton dimension 1
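The `compute_features` frame above shows the feature transform in two steps: subtract the centering vector, then Lp-normalize each feature vector (p=2.0 here, via `feature_normalization`). Below is a plain-PyTorch sketch of that transform, followed by nearest-prototype prediction and the correct/total count that `evaluate_on_one_task` returns. The helper names are hypothetical. Euclidean distance is an assumption made for this sketch: this is not EasyFSL's code, and EasyFSL's `SimpleShot` may score queries differently. The explicit shape check turns the cryptic broadcasting error into a readable one.

```python
import torch
import torch.nn.functional as F


def transform_features(features: torch.Tensor, centering: torch.Tensor, p: float = 2.0) -> torch.Tensor:
    """Subtract the average training feature, then Lp-normalize each row."""
    if centering.shape[-1] != features.shape[-1]:
        # Fail early with a clearer message than the broadcasting RuntimeError.
        raise ValueError(
            f"centering has dimension {centering.shape[-1]}, "
            f"but features have dimension {features.shape[-1]}"
        )
    return F.normalize(features - centering, p=p, dim=1)


def nearest_prototype_predictions(support, support_labels, query):
    """Predict, for each query, the class whose prototype (mean support feature) is closest."""
    classes = torch.unique(support_labels)  # sorted class ids
    prototypes = torch.stack([support[support_labels == c].mean(dim=0) for c in classes])
    distances = torch.cdist(query, prototypes)  # (n_query, n_classes), Euclidean
    return classes[distances.argmin(dim=1)]


# Tiny 2-way task with 4-dimensional features and zero centering, for illustration only.
centering = torch.zeros(4)
support = torch.tensor([[1.0, 0.0, 0.0, 0.0], [0.9, 0.1, 0.0, 0.0],
                        [0.0, 1.0, 0.0, 0.0], [0.0, 0.9, 0.1, 0.0]])
support_labels = torch.tensor([0, 0, 1, 1])
query = torch.tensor([[0.8, 0.2, 0.0, 0.0], [0.1, 0.7, 0.0, 0.0]])
query_labels = torch.tensor([0, 1])

predictions = nearest_prototype_predictions(
    transform_features(support, centering), support_labels, transform_features(query, centering)
)
correct = int((predictions == query_labels).sum().item())
total = len(query_labels)
print(correct, total)
```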

How can we help

Can you help me identify the issue? Thanks again!

kimchiTuna commented 4 months ago

I solved the issue by turning off the final fully connected (FC) layer before computing the feature centering. Is this the correct way to do it?
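The reasoning behind this fix can be checked with a self-contained toy model. `ToyBackbone` and `average_features` are hypothetical stand-ins. Only the method name `set_use_fc` comes from EasyFSL's ResNet (it is used in the notebook cell in the traceback above), and `average_features` only imitates what a centering vector represents; it is not the real `compute_average_features_from_images`. The sketch shows that the centering vector takes the length of whatever the model outputs at the moment it is computed. It therefore matches the 640-dimensional validation features only if the FC layer is off at that point.

```python
import torch
from torch import nn


class ToyBackbone(nn.Module):
    """Hypothetical backbone: 640-d features, plus a 10-class FC head that can be toggled."""

    def __init__(self, feature_dim: int = 640, n_classes: int = 10):
        super().__init__()
        self.body = nn.Sequential(nn.Flatten(), nn.Linear(3 * 8 * 8, feature_dim), nn.ReLU())
        self.fc = nn.Linear(feature_dim, n_classes)
        self.use_fc = True

    def set_use_fc(self, use_fc: bool) -> None:
        # Same method name as EasyFSL's ResNet; this toy version just flips a flag.
        self.use_fc = use_fc

    def forward(self, images: torch.Tensor) -> torch.Tensor:
        features = self.body(images)
        return self.fc(features) if self.use_fc else features


@torch.no_grad()
def average_features(model: nn.Module, batches) -> torch.Tensor:
    """Mean model output over all images: what a feature-centering vector represents."""
    return torch.cat([model(images) for images in batches]).mean(dim=0)


torch.manual_seed(0)
model = ToyBackbone().eval()
train_batches = [torch.randn(4, 3, 8, 8) for _ in range(3)]  # fake 8x8 RGB images

# Original order: FC still enabled, so the "features" being averaged are 10 class scores.
wrong_centering = average_features(model, train_batches)

# Order from the follow-up comment: disable the FC layer, compute the centering,
# then re-enable it for the classical training loop.
model.set_use_fc(False)
feature_centering = average_features(model, train_batches)
model.set_use_fc(True)

# During validation the FC layer is disabled again, and the shapes now agree.
model.set_use_fc(False)
centered = model(torch.randn(5, 3, 8, 8)) - feature_centering
model.set_use_fc(True)

print(wrong_centering.shape, feature_centering.shape, centered.shape)
```

In the notebook, this corresponds to calling `model.set_use_fc(False)` before `compute_average_features_from_images(train_loader, model, DEVICE)` and `model.set_use_fc(True)` after it. Under the assumptions of this sketch, that keeps the centering vector and the validation features the same length.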