badripatro / simba

Simba

Problem with torch.FloatTensor and torch.cuda.FloatTensor #15

Open duc-anh-2002 opened 1 month ago

duc-anh-2002 commented 1 month ago

Thank you for your repository. I am currently running into the following problem:

  File "main_debug.py", line 495, in main
    train_stats = train_one_epoch(
  File "/home/ubuntu/workspace/mamba/simba/classification/engine.py", line 48, in train_one_epoch
    outputs = model(samples)
  File "/home/ubuntu/workspace/envs/simba/lib/python3.8/site-packages/torch/nn/modules/module.py", line 1130, in _call_impl
    return forward_call(*input, **kwargs)
  File "/home/ubuntu/workspace/mamba/simba/classification/simba_debug.py", line 603, in forward
    x, H, W = self.forward_embeddings(x)
  File "/home/ubuntu/workspace/mamba/simba/classification/simba_debug.py", line 659, in forward_embeddings
    x, H, W = patch_embed(x)
  File "/home/ubuntu/workspace/envs/simba/lib/python3.8/site-packages/torch/nn/modules/module.py", line 1130, in _call_impl
    return forward_call(*input, **kwargs)
  File "/home/ubuntu/workspace/mamba/simba/classification/simba_debug.py", line 477, in forward
    x = self.conv(torch.FloatTensor(x))
  File "/home/ubuntu/workspace/envs/simba/lib/python3.8/site-packages/torch/nn/modules/module.py", line 1130, in _call_impl
    return forward_call(*input, **kwargs)
  File "/home/ubuntu/workspace/envs/simba/lib/python3.8/site-packages/torch/nn/modules/container.py", line 139, in forward
    input = module(input)
  File "/home/ubuntu/workspace/envs/simba/lib/python3.8/site-packages/torch/nn/modules/module.py", line 1130, in _call_impl
    return forward_call(*input, **kwargs)
  File "/home/ubuntu/workspace/envs/simba/lib/python3.8/site-packages/torch/nn/modules/conv.py", line 457, in forward
    return self._conv_forward(input, self.weight, self.bias)
  File "/home/ubuntu/workspace/envs/simba/lib/python3.8/site-packages/torch/nn/modules/conv.py", line 453, in _conv_forward
    return F.conv2d(input, weight, bias, self.stride,
RuntimeError: Input type (torch.FloatTensor) and weight type (torch.cuda.FloatTensor) should be the same or input should be a MKLDNN tensor and weight is a dense tensor

Could you help me fix this bug? Thank you.
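The traceback shows the input reaching the convolution as a CPU tensor (`torch.FloatTensor`) while the conv weights are on the GPU (`torch.cuda.FloatTensor`). Line 477 of `simba_debug.py` wraps the input in `torch.FloatTensor(...)`, which is the CPU float tensor type and can never produce a CUDA tensor. A minimal sketch of one possible fix, assuming the wrap was only meant to force float32: drop it, cast with `.to(dtype=torch.float32)` (which keeps the tensor's device), and move the input onto the same device as the conv weights. `PatchEmbedSketch` and its constructor arguments are hypothetical stand-ins, not the repository's actual patch-embedding module.

```python
import torch
import torch.nn as nn


class PatchEmbedSketch(nn.Module):
    """Hypothetical stand-in for the patch-embedding module in simba_debug.py."""

    def __init__(self, in_chans: int = 3, embed_dim: int = 32, patch_size: int = 4):
        super().__init__()
        # The traceback shows self.conv is an nn.Sequential (container.py).
        self.conv = nn.Sequential(
            nn.Conv2d(in_chans, embed_dim, kernel_size=patch_size, stride=patch_size),
        )

    def forward(self, x: torch.Tensor):
        weight = self.conv[0].weight
        # Instead of torch.FloatTensor(x), which is a CPU-only type, cast to
        # float32 and move x onto whatever device the conv weights live on.
        x = x.to(device=weight.device, dtype=torch.float32)
        x = self.conv(x)
        _, _, H, W = x.shape
        # Assumption: the real module may reshape x further (e.g. into tokens).
        return x, H, W


device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
model = PatchEmbedSketch().to(device)
samples = torch.randn(2, 3, 32, 32)  # starts on the CPU, like a batch never moved
out, H, W = model(samples)
print(out.device, tuple(out.shape), H, W)
```

Moving the batch once in the training loop, before `model(samples)` is called (for example `samples = samples.to(device, non_blocking=True)`), is the more usual place for this; the `.to(...)` inside `forward` is only a defensive fallback and adds a small check on every call.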