Hi, thanks for the great repo. I am new to few-shot learning (FSL) and have run into a problem that I cannot solve.
Problem
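I am using the classical training notebook and trying to implement SimpleShot instead of Prototypical Networks. I am not sure whether I fully understand the code, but I implemented the following and successfully predicted embeddings:
Result: Predicting embeddings: 100%|██████████████████████████████████████████████████████████| 4/4 [00:03<00:00, 1.30batch/s]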
However, I get the following error during validation: "RuntimeError: The size of tensor a (640) must match the size of tensor b (10) at non-singleton dimension 1".
Result:
```
Validation: 0%| | 0/500 [00:02<?, ?it/s]
---------------------------------------------------------------------------
RuntimeError Traceback (most recent call last)
Cell In[8], line 16
11 if epoch % validation_frequency == validation_frequency - 1:
12
13 # We use this very convenient method from EasyFSL's ResNet to specify
14 # that the model shouldn't use its last fully connected layer during validation.
15 model.set_use_fc(False)
---> 16 validation_accuracy = evaluate(
17 few_shot_classifier, val_loader, device=DEVICE, tqdm_prefix="Validation"
18 )
19 model.set_use_fc(True)
21 if validation_accuracy > best_validation_accuracy:
File ~\Documents\GitHub\easy-few-shot-learning-master\easyfsl\utils.py:157, in evaluate(model, data_loader, device, use_tqdm, tqdm_prefix)
144 with tqdm(
145 enumerate(data_loader),
146 total=len(data_loader),
147 disable=not use_tqdm,
148 desc=tqdm_prefix,
149 ) as tqdm_eval:
150 for _, (
151 support_images,
152 support_labels,
(...)
155 _,
156 ) in tqdm_eval:
--> 157 correct, total = evaluate_on_one_task(
158 model,
159 support_images.to(device),
160 support_labels.to(device),
161 query_images.to(device),
162 query_labels.to(device),
163 )
165 outputs = model(query_images.to(device))
166 LOSS_FUNCTION = nn.CrossEntropyLoss()
File ~\Documents\GitHub\easy-few-shot-learning-master\easyfsl\utils.py:104, in evaluate_on_one_task(model, support_images, support_labels, query_images, query_labels)
93 def evaluate_on_one_task(
94 model: FewShotClassifier,
95 support_images: Tensor,
(...)
98 query_labels: Tensor,
99 ) -> Tuple[int, int]:
100 """
101 Returns the number of correct predictions of query labels, and the total number of
102 predictions.
103 """
--> 104 model.process_support_set(support_images, support_labels)
105 predictions = model(query_images).detach().data
106 number_of_correct_predictions = int(
107 (torch.max(predictions, 1)[1] == query_labels).sum().item()
108 )
File ~\Documents\GitHub\easy-few-shot-learning-master\easyfsl\methods\few_shot_classifier.py:77, in FewShotClassifier.process_support_set(self, support_images, support_labels)
65 def process_support_set(
66 self,
67 support_images: Tensor,
68 support_labels: Tensor,
69 ):
70 """
71 Harness information from the support set, so that query labels can later be predicted using a forward call.
72 The default behaviour shared by most few-shot classifiers is to compute prototypes and store the support set.
(...)
75 support_labels: labels of support set images of shape (n_support, )
76 """
---> 77 self.compute_prototypes_and_store_support_set(support_images, support_labels)
File ~\Documents\GitHub\easy-few-shot-learning-master\easyfsl\methods\few_shot_classifier.py:148, in FewShotClassifier.compute_prototypes_and_store_support_set(self, support_images, support_labels)
141 """
142 Extract support features, compute prototypes, and store support labels, features, and prototypes.
143 Args:
144 support_images: images of the support set of shape (n_support, **image_shape)
145 support_labels: labels of support set images of shape (n_support, )
146 """
147 self.support_labels = support_labels
--> 148 self.support_features = self.compute_features(support_images)
149 self._raise_error_if_features_are_multi_dimensional(self.support_features)
150 self.prototypes = compute_prototypes(self.support_features, support_labels)
File ~\Documents\GitHub\easy-few-shot-learning-master\easyfsl\methods\few_shot_classifier.py:94, in FewShotClassifier.compute_features(self, images)
86 """
87 Compute features from images and perform centering and normalization.
88 Args:
(...)
91 features of shape (n_images, feature_dimension)
92 """
93 original_features = self.backbone(images)
---> 94 centered_features = original_features - self.feature_centering
95 if self.feature_normalization is not None:
96 return nn.functional.normalize(
97 centered_features, p=self.feature_normalization, dim=1
98 )
RuntimeError: The size of tensor a (640) must match the size of tensor b (10) at non-singleton dimension 1
```
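The failing line 94 is a plain broadcast subtraction, `original_features - self.feature_centering`. The minimal sketch below uses random tensors to reproduce the same message. It assumes that the backbone outputs 640-dimensional features once `set_use_fc(False)` is called, and that the `feature_centering` tensor ended up with 10 elements. That could happen, for example, if it was averaged over outputs computed with the fc layer still enabled, or averaged over the wrong dimension. A centering vector of shape (640,) subtracts cleanly. All sizes and variable names here are hypothetical and only for illustration.

```python
import torch

# Hypothetical sizes: 5 images, 640-dim backbone features, 10 fc outputs
n_images, feature_dim, n_outputs = 5, 640, 10

# Stand-in for self.backbone(images) with the fc layer disabled: (n_images, 640)
original_features = torch.randn(n_images, feature_dim)

# Hypothetical bad centering vector: a mean over 10-dimensional outputs,
# so its shape is (10,) instead of (640,)
bad_centering = torch.randn(n_images, n_outputs).mean(dim=0)

try:
    original_features - bad_centering
    mismatch_raised = False
except RuntimeError as error:
    # Same message as in the traceback: 640 vs 10 at dimension 1
    mismatch_raised = True
    print(error)

# Centering vector built from 640-dim backbone features (fc disabled),
# averaged over the image dimension (dim=0): shape (640,)
base_features = torch.randn(100, feature_dim)
good_centering = base_features.mean(dim=0)

# (n_images, 640) - (640,) broadcasts row-wise, then L2-normalize each row
centered = original_features - good_centering
normalized = torch.nn.functional.normalize(centered, p=2, dim=1)
print(good_centering.shape, normalized.shape)
```

The key point is that the centering vector must have one entry per feature dimension of the backbone output used at validation time.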
How can we help
Can you help me identify the issue? Thanks again!