sksq96 / pytorch-summary

Model summary in PyTorch similar to `model.summary()` in Keras
MIT License
4.01k stars 412 forks source link

RuntimeError: Input type (torch.cuda.FloatTensor) and weight type (torch.FloatTensor) should be the same #181

Closed rohan-paul closed 2 years ago

rohan-paul commented 2 years ago

After running the example code from the doc

import torch
import torch.nn as nn
import torch.nn.functional as F
from torchsummary import summary

class Net(nn.Module):  # small CNN: two conv layers followed by two linear layers
    def __init__(self):
        super(Net, self).__init__()
        self.conv1 = nn.Conv2d(1, 10, kernel_size=5)  # 1 input channel -> 10 output channels
        self.conv2 = nn.Conv2d(10, 20, kernel_size=5)  # 10 -> 20 output channels
        self.conv2_drop = nn.Dropout2d()  # applied to the output of conv2 in forward()
        self.fc1 = nn.Linear(320, 50)  # 320 matches the flatten in forward(): 20 channels x 4 x 4 for a 1x28x28 input
        self.fc2 = nn.Linear(50, 10)  # 10 output values per sample

    def forward(self, x):  # returns log-softmax over dim 1
        x = F.relu(F.max_pool2d(self.conv1(x), 2))  # conv -> 2x2 max-pool -> ReLU
        x = F.relu(F.max_pool2d(self.conv2_drop(self.conv2(x)), 2))  # conv -> dropout -> 2x2 max-pool -> ReLU
        x = x.view(-1, 320)  # flatten to rows of 320 features
        x = F.relu(self.fc1(x))
        x = F.dropout(x, training=self.training)  # only active while the module is in training mode
        x = self.fc2(x)
        return F.log_softmax(x, dim=1)

# device = torch.device("cuda" if torch.cuda.is_available() else "cpu") # PyTorch v0.4.0
device = "cpu"  # the model is explicitly placed on the CPU
model = Net().to(device)

summary(model, (1, 28, 28))  # NOTE(review): no device= argument is passed; the error below reports a CUDA input against CPU weights

I get the following error:

RuntimeError                              Traceback (most recent call last)
/tmp/ipykernel_108660/4081865440.py in <module>
     26 model = Net().to(device)
     27 
---> 28 summary(model, (1, 28, 28))

~/.local/lib/python3.9/site-packages/torchsummary/torchsummary.py in summary(model, input_size, batch_size, device)
     70     # make a forward pass
     71     # print(x.shape)
---> 72     model(*x)
     73 
     74     # remove these hooks

~/.local/lib/python3.9/site-packages/torch/nn/modules/module.py in _call_impl(self, *input, **kwargs)
   1100         if not (self._backward_hooks or self._forward_hooks or self._forward_pre_hooks or _global_backward_hooks
   1101                 or _global_forward_hooks or _global_forward_pre_hooks):
-> 1102             return forward_call(*input, **kwargs)
   1103         # Do not call functions when jit is used
   1104         full_backward_hooks, non_full_backward_hooks = [], []

/tmp/ipykernel_108660/4081865440.py in forward(self, x)
     14 
     15     def forward(self, x):
---> 16         x = F.relu(F.max_pool2d(self.conv1(x), 2))
     17         x = F.relu(F.max_pool2d(self.conv2_drop(self.conv2(x)), 2))
     18         x = x.view(-1, 320)

~/.local/lib/python3.9/site-packages/torch/nn/modules/module.py in _call_impl(self, *input, **kwargs)
   1118             input = bw_hook.setup_input_hook(input)
   1119 
-> 1120         result = forward_call(*input, **kwargs)
   1121         if _global_forward_hooks or self._forward_hooks:
   1122             for hook in (*_global_forward_hooks.values(), *self._forward_hooks.values()):

~/.local/lib/python3.9/site-packages/torch/nn/modules/conv.py in forward(self, input)
    444 
    445     def forward(self, input: Tensor) -> Tensor:
--> 446         return self._conv_forward(input, self.weight, self.bias)
    447 
    448 class Conv3d(_ConvNd):

~/.local/lib/python3.9/site-packages/torch/nn/modules/conv.py in _conv_forward(self, input, weight, bias)
    440                             weight, bias, self.stride,
    441                             _pair(0), self.dilation, self.groups)
--> 442         return F.conv2d(input, weight, bias, self.stride,
    443                         self.padding, self.dilation, self.groups)
    444 

RuntimeError: Input type (torch.cuda.FloatTensor) and weight type (torch.FloatTensor) should be the same
m-zheng commented 2 years ago

Try passing the device explicitly so that the input is created on the same device as the model: summary(model, (1, 28, 28), device='cpu')

rohan-paul commented 2 years ago

That worked, thanks... closing.