wongsihan closed this issue 4 years ago.
Hi,
The main reason is that the hyper-parameters were not set correctly: the default parameters use a meta-training shot of 3 and a meta-validation shot of 1, so the get_proto function could not reshape the tensor correctly.
I have fixed this bug (and also changed one parameter in get_proto).
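To make the mismatch concrete, here is a shape-only sketch in plain Python, so it runs without PyTorch. It assumes a 5-way episode with 15 queries per class, that `z` has one row per support-plus-query instance, and that `h` holds the embeddings of those same instances. Under those assumptions 5×3 + 5×15 = 90 and 5×1 + 5×15 = 80, which matches the sizes in the reported `bmm` error. `episode_instances`, `bmm_shape`, and `pick_shot` are hypothetical helpers, not functions from the FEAT code base.

```python
# Shape-only sketch of the get_proto mismatch. The layouts of z and h are
# assumptions chosen to reproduce the reported numbers, not FEAT's real code.


def episode_instances(way, shot, query):
    """Support plus query instances in one way/shot/query episode (hypothetical)."""
    return way * shot + way * query


def bmm_shape(batch1, batch2):
    """Apply torch.bmm's shape rule: (b, n, m) x (b, m, p) -> (b, n, p)."""
    b1, n, m1 = batch1
    b2, m2, p = batch2
    if b1 != b2:
        raise RuntimeError(f"batch sizes differ: {b1} vs {b2}")
    if m1 != m2:
        # Same wording as the RuntimeError in the traceback.
        raise RuntimeError(
            f"Expected tensor to have size {m1} at dimension 1, "
            f"but got size {m2} for argument #2 'batch2'"
        )
    return (b1, n, p)


def pick_shot(training, shot, eval_shot):
    """Use the meta-training shot while training, the evaluation shot otherwise."""
    return shot if training else eval_shot


WAY, QUERY, DIM = 5, 15, 64   # assumed episode settings and embedding size
TRAIN_SHOT, EVAL_SHOT = 3, 1  # the mismatched defaults described in the reply

# h: embeddings of every support + query instance in a validation episode.
h = (1, episode_instances(WAY, EVAL_SHOT, QUERY), DIM)  # (1, 80, 64)

# Buggy: z is sized with the training shot, so z.permute([0, 2, 1]) is (1, 5, 90).
z_t_buggy = (1, WAY, episode_instances(WAY, TRAIN_SHOT, QUERY))
try:
    bmm_shape(z_t_buggy, h)
except RuntimeError as err:
    print(err)  # Expected tensor to have size 90 at dimension 1, but got size 80 ...

# Fixed: size z with the shot that matches the current phase.
shot = pick_shot(training=False, shot=TRAIN_SHOT, eval_shot=EVAL_SHOT)
z_t_fixed = (1, WAY, episode_instances(WAY, shot, QUERY))
print(bmm_shape(z_t_fixed, h))  # (1, 5, 64)
```

In the real model the equivalent change would be to read the shot from the evaluation settings whenever the model is not training; the exact argument name FEAT uses for that should be checked against the repository.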
Please also note that "SemiProtoFEAT" implements a transductive learning method; use FEAT for inductive few-shot learning.
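One way to picture the transductive/inductive split: an inductive prototype uses only the labelled support embeddings, while a transductive one may also use the unlabelled queries of the same episode (the call `get_proto(support, query)` in the traceback hints at this). The pure-Python sketch below uses a single soft k-means refinement step, in the style of semi-supervised prototypical networks. It is an illustrative stand-in, not SemiProtoFEAT's actual `get_proto`, and all helper names are hypothetical.

```python
import math


def mean(vectors):
    """Element-wise mean of equal-length vectors."""
    return [sum(col) / len(vectors) for col in zip(*vectors)]


def sq_dist(a, b):
    """Squared Euclidean distance."""
    return sum((x - y) ** 2 for x, y in zip(a, b))


def inductive_prototypes(support, labels, way):
    """Inductive: each prototype averages the labelled support embeddings of its class."""
    return [mean([s for s, y in zip(support, labels) if y == c]) for c in range(way)]


def soft_assign(query, protos):
    """Softmax over negative squared distances, one row per query embedding."""
    out = []
    for q in query:
        logits = [-sq_dist(q, p) for p in protos]
        top = max(logits)  # subtract the max for numerical stability
        exps = [math.exp(l - top) for l in logits]
        total = sum(exps)
        out.append([e / total for e in exps])
    return out


def transductive_prototypes(support, labels, query, way):
    """Transductive: refine the inductive prototypes with soft-assigned queries."""
    weights = soft_assign(query, inductive_prototypes(support, labels, way))
    protos = []
    for c in range(way):
        members = [s for s, y in zip(support, labels) if y == c]
        total = float(len(members))  # each labelled support example has weight 1
        acc = [sum(col) for col in zip(*members)]
        for q, w in zip(query, weights):
            total += w[c]
            acc = [a + w[c] * x for a, x in zip(acc, q)]
        protos.append([a / total for a in acc])
    return protos


support = [[0.0, 0.0], [4.0, 0.0]]  # one labelled example per class (2-way 1-shot)
labels = [0, 1]
query = [[0.0, 1.0], [4.0, 1.0]]    # unlabelled queries
print(inductive_prototypes(support, labels, 2))            # [[0.0, 0.0], [4.0, 0.0]]
print(transductive_prototypes(support, labels, query, 2))  # roughly [[0.0, 0.5], [4.0, 0.5]]
```

The query-dependent step is what makes the method transductive: the prototypes shift with the query batch, so the prediction for one query can depend on the others in the same episode.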
Thank you for your reply!
Hello, I am new to deep learning and have run into a simple problem that I cannot solve yet, so I would like to ask for your advice. I downloaded the CUB and MiniImageNet datasets and put them in the right place, but when I run train_fsl.py, the program reports the following error.
100%|█████████████████| 38400/38400 [00:00<00:00, 314378.16it/s]
100%|███████████████████| 9600/9600 [00:00<00:00, 370283.04it/s]
100%|█████████████████| 12000/12000 [00:00<00:00, 315266.39it/s]
best epoch 0, best val acc=0.0000 + 0.0000
Traceback (most recent call last):
  File "/Users/sihanwang/Desktop/FEAT/train_fsl.py", line 19, in <module>
    trainer.train()
  File "/Users/sihanwang/Desktop/FEAT/model/trainer/fsl_trainer.py", line 105, in train
    self.try_evaluate(epoch)
  File "/Users/sihanwang/Desktop/FEAT/model/trainer/base.py", line 53, in try_evaluate
    vl, va, vap = self.evaluate(self.val_loader)
  File "/Users/sihanwang/Desktop/FEAT/model/trainer/fsl_trainer.py", line 136, in evaluate
    logits = self.model(data)
  File "/Users/sihanwang/.conda/envs/input_bear_data.py/lib/python3.7/site-packages/torch/nn/modules/module.py", line 550, in __call__
    result = self.forward(*input, **kwargs)
  File "/Users/sihanwang/Desktop/FEAT/model/models/base.py", line 51, in forward
    logits = self._forward(instance_embs, support_idx, query_idx)
  File "/Users/sihanwang/Desktop/FEAT/model/models/semi_protofeat.py", line 135, in _forward
    proto = self.get_proto(support, query) # we can also use adapted query set here to achieve better results
  File "/Users/sihanwang/Desktop/FEAT/model/models/semi_protofeat.py", line 114, in get_proto
    proto = torch.bmm(z.permute([0,2,1]), h)
RuntimeError: Expected tensor to have size 90 at dimension 1, but got size 80 for argument #2 'batch2' (while checking arguments for bmm)
I understand that the dimensions of z and h do not match, so the matrices cannot be multiplied, but I don't know how to fix it. I would appreciate any advice if it is convenient. Thank you.
P.S. I changed ROOT_PATH from data/CUB/images to FEAT/data/CUB/images; otherwise the program could not find the image data.
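A relative path such as `data/CUB/images` is resolved against the current working directory, not against the script's location, so it only works when train_fsl.py is launched from inside the FEAT folder; `FEAT/data/CUB/images` in turn only works from the folder above it. One way to make the path independent of where the command is run is to anchor it to a source file's location. The sketch below is a hypothetical helper built on the standard-library `pathlib`; the names `resolve_data_dir` and `levels_up` are illustrative, and the way FEAT's dataloaders actually build `ROOT_PATH` may differ.

```python
from pathlib import Path


def resolve_data_dir(relative, anchor_file, levels_up=0):
    """Resolve `relative` against a directory derived from `anchor_file`.

    Hypothetical helper. `levels_up=0` uses the folder that contains
    `anchor_file`; `levels_up=2` climbs two more folders, e.g. from
    model/dataloader/<module>.py back to the project root.
    """
    project_root = Path(anchor_file).resolve().parents[levels_up]
    return project_root / relative


# Usage inside train_fsl.py (assumed to sit in the project root):
#     ROOT_PATH = resolve_data_dir("data/CUB/images", __file__)
# Usage inside a module two folders deep, such as model/dataloader/:
#     ROOT_PATH = resolve_data_dir("data/CUB/images", __file__, levels_up=2)
```

Because the result depends only on where the source file lives, the same command works whether it is started from inside FEAT or from its parent folder.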