Sha-Lab / FEAT

The code repository for "Few-Shot Learning via Embedding Adaptation with Set-to-Set Functions"
MIT License

A question about a training error #36

Closed wongsihan closed 4 years ago

wongsihan commented 4 years ago

Hello, I am a new student in deep learning. I have a simple problem that I cannot solve at present, and I would like to ask for your advice. I downloaded the CUB and miniImageNet datasets and put them in the right place. When I run train_fsl.py, the program reports an error.

100%|█████████████████| 38400/38400 [00:00<00:00, 314378.16it/s]
100%|███████████████████| 9600/9600 [00:00<00:00, 370283.04it/s]
100%|█████████████████| 12000/12000 [00:00<00:00, 315266.39it/s]
best epoch 0, best val acc=0.0000 + 0.0000
Traceback (most recent call last):
  File "/Users/sihanwang/Desktop/FEAT/train_fsl.py", line 19, in <module>
    trainer.train()
  File "/Users/sihanwang/Desktop/FEAT/model/trainer/fsl_trainer.py", line 105, in train
    self.try_evaluate(epoch)
  File "/Users/sihanwang/Desktop/FEAT/model/trainer/base.py", line 53, in try_evaluate
    vl, va, vap = self.evaluate(self.val_loader)
  File "/Users/sihanwang/Desktop/FEAT/model/trainer/fsl_trainer.py", line 136, in evaluate
    logits = self.model(data)
  File "/Users/sihanwang/.conda/envs/input_bear_data.py/lib/python3.7/site-packages/torch/nn/modules/module.py", line 550, in __call__
    result = self.forward(*input, **kwargs)
  File "/Users/sihanwang/Desktop/FEAT/model/models/base.py", line 51, in forward
    logits = self._forward(instance_embs, support_idx, query_idx)
  File "/Users/sihanwang/Desktop/FEAT/model/models/semi_protofeat.py", line 135, in _forward
    proto = self.get_proto(support, query) # we can also use adapted query set here to achieve better results
  File "/Users/sihanwang/Desktop/FEAT/model/models/semi_protofeat.py", line 114, in get_proto
    proto = torch.bmm(z.permute([0,2,1]), h)
RuntimeError: Expected tensor to have size 90 at dimension 1, but got size 80 for argument #2 'batch2' (while checking arguments for bmm)

I know this is because the dimensions of z and h do not match, so the matrices cannot be multiplied, but I don't know how to fix it. I would appreciate it if you could tell me how. Thank you. P.S. I changed ROOT_PATH from data/CUB/images to FEAT/data/CUB/images; otherwise the program cannot find the image data.
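For readers hitting the same message, here is a minimal sketch of the shape rule behind it. torch.bmm multiplies a (b, n, m) tensor by a (b, m, p) tensor, so the last dimension of the first argument must equal dimension 1 of the second. The tensor sizes and the helper name try_bmm below are hypothetical and chosen only to mirror the 90 vs. 80 in the traceback; they are not taken from the repository.

```python
import torch


def try_bmm(z, h):
    """Return the shape of bmm(z^T, h), or the error text if the shapes clash."""
    try:
        return tuple(torch.bmm(z.permute([0, 2, 1]), h).shape)
    except RuntimeError as err:
        return str(err)


# Hypothetical sizes: 1 task, 5 classes, 64-dimensional embeddings.
batch, n_way, dim = 1, 5, 64

h = torch.zeros(batch, 80, dim)        # 80 instance embeddings
z_ok = torch.zeros(batch, 80, n_way)   # weights for the same 80 instances
z_bad = torch.zeros(batch, 90, n_way)  # weights built for 90 instances

ok_shape = try_bmm(z_ok, h)     # shapes agree: (1, 5, 80) x (1, 80, 64)
bad_result = try_bmm(z_bad, h)  # shapes clash: (1, 5, 90) x (1, 80, 64)

print(ok_shape)
print(bad_result)
```

The exact wording of the error differs between PyTorch versions, but in each case the fix is to make the instance dimension of z and h agree.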

Han-Jia commented 4 years ago

Hi,

The main cause is that the hyper-parameters do not match: the default parameters use meta-training shot = 3 and meta-val shot = 1, so the get_proto function could not reshape the tensor correctly.
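The sizes in the traceback are consistent with this explanation. As a rough check, assuming 5-way episodes with 15 query samples per class (these two values are an assumption and are not stated in this thread), an episode holds way * shot + way * query instances. The helper episode_size is a hypothetical name used only for this sketch.

```python
def episode_size(way, shot, query):
    """Number of instances in one episode: support samples plus query samples."""
    return way * shot + way * query


# Assumed episode layout; not stated in the thread.
WAY, QUERY = 5, 15
# Default shots quoted in the reply above.
TRAIN_SHOT, VAL_SHOT = 3, 1

train_size = episode_size(WAY, TRAIN_SHOT, QUERY)  # 5*3 + 5*15
val_size = episode_size(WAY, VAL_SHOT, QUERY)      # 5*1 + 5*15

print(train_size, val_size)
```

Under these assumptions the training episode has 90 instances and the validation episode has 80, matching the two sizes in the bmm error. A tensor shaped with the training shot therefore cannot be multiplied with one built from a validation episode.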

I have fixed this bug (and also changed one parameter in get_proto).

Please also note that "SemiProtoFEAT" implements a transductive learning method. Please use FEAT for inductive few-shot learning.

wongsihan commented 4 years ago

Thank you for your reply!