AlessioGalluccio / FastFlow

An implementation of the FastFlow architecture (Jiawei Yu et al.)
MIT License

RuntimeError: Input type (torch.cuda.FloatTensor) and weight type (torch.FloatTensor) should be the same #5

Closed: cyj95 closed this issue 2 years ago.

cyj95 commented 2 years ago

    Traceback (most recent call last):
      File "main.py", line 29, in <module>
        model = train(train_loader, test_loader)
      File "/home/robin/cuiyajie/FastFlow-master/train.py", line 65, in train
        model = FastFlow()
      File "/home/robin/cuiyajie/FastFlow-master/model.py", line 104, in __init__
        print(summary(self.feature_extractor, (3,384,384)))
      File "/home/robin/anaconda3/envs/patchcore/lib/python3.7/site-packages/torchsummary/torchsummary.py", line 72, in summary
        model(*x)
      File "/home/robin/anaconda3/envs/patchcore/lib/python3.7/site-packages/torch/nn/modules/module.py", line 1051, in _call_impl
        return forward_call(*input, **kwargs)
      File "/home/robin/anaconda3/envs/patchcore/lib/python3.7/site-packages/torch/nn/modules/container.py", line 139, in forward
        input = module(input)
      File "/home/robin/anaconda3/envs/patchcore/lib/python3.7/site-packages/torch/nn/modules/module.py", line 1071, in _call_impl
        result = forward_call(*input, **kwargs)
      File "/home/robin/anaconda3/envs/patchcore/lib/python3.7/site-packages/timm/models/layers/patch_embed.py", line 35, in forward
        x = self.proj(x)
      File "/home/robin/anaconda3/envs/patchcore/lib/python3.7/site-packages/torch/nn/modules/module.py", line 1071, in _call_impl
        result = forward_call(*input, **kwargs)
      File "/home/robin/anaconda3/envs/patchcore/lib/python3.7/site-packages/torch/nn/modules/conv.py", line 443, in forward
        return self._conv_forward(input, self.weight, self.bias)
      File "/home/robin/anaconda3/envs/patchcore/lib/python3.7/site-packages/torch/nn/modules/conv.py", line 440, in _conv_forward
        self.padding, self.dilation, self.groups)
    RuntimeError: Input type (torch.cuda.FloatTensor) and weight type (torch.FloatTensor) should be the same

simo-an commented 2 years ago

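Update the code in model.py as follows. torchsummary's summary() builds its dummy input on CUDA by default, while the feature extractor is still on the CPU, so move the model to c.device and pass the same device to summary():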

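    # Move the backbone weights to the configured device
    self.feature_extractor.to(c.device)
    # torchsummary creates its dummy input on this device; it accepts only "cuda" or "cpu"
    print(summary(self.feature_extractor, (3,256,256), device=c.device))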
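A minimal, self-contained sketch of the same idea, assuming a recent PyTorch. A hypothetical single-conv backbone stands in for the timm feature extractor, and the helper names pick_device, summary_device and forward_on_device are illustrative only, not part of this repository. The point is the one the fix makes: put the model weights and the dummy input on one device before the forward pass.

```python
import torch
import torch.nn as nn


def pick_device() -> str:
    # Prefer the GPU when one is visible, otherwise fall back to the CPU.
    return "cuda" if torch.cuda.is_available() else "cpu"


def summary_device(device) -> str:
    # Hypothetical helper: torchsummary only accepts "cuda" or "cpu",
    # so drop an index such as ":0" from a string or torch.device.
    return str(device).split(":")[0]


def forward_on_device(model: nn.Module, input_size, device: str) -> torch.Tensor:
    # nn.Module.to moves the parameters in place and returns the module.
    model.to(device)
    # Build the dummy batch on the same device, so input and weight types match.
    x = torch.rand(1, *input_size, device=device)
    with torch.no_grad():
        return model(x)


# Hypothetical stand-in for a ViT-style patch embedding: 16x16 conv, stride 16.
backbone = nn.Sequential(nn.Conv2d(3, 8, kernel_size=16, stride=16))
device = pick_device()
out = forward_on_device(backbone, (3, 256, 256), device)
print(out.shape)  # torch.Size([1, 8, 16, 16])
```

With torchsummary installed, the same device string can then be handed to it, e.g. summary(backbone, (3, 256, 256), device=summary_device(device)). torchsummary does not move the model itself, so the .to(device) call is still needed.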
AlessioGalluccio commented 2 years ago

Thank you very much. I've changed the code; let me know if you still have this problem.