sksq96 / pytorch-summary

Model summary in PyTorch similar to `model.summary()` in Keras
MIT License
4.01k stars 412 forks source link

summary does not work with the torch.device class #199

Open Skaifai opened 8 months ago

Skaifai commented 8 months ago

`torchsummary.summary` fails when its `device` argument is a `torch.device` object rather than a string such as `"cuda"`. Code to reproduce the error:

import torch
import torch.nn as nn
from torchvision import models
from torchsummary import summary

# Pick the GPU when PyTorch can see one; otherwise run on the CPU.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
print("Using ", device)

class CNN(nn.Module):
    """Small CNN used to reproduce the torchsummary device bug.

    Two conv + max-pool stages followed by three fully connected layers.
    ``fc1`` expects ``16 * 5 * 5`` flattened features, which corresponds to
    a 3x32x32 input (32 -> 28 -> 14 -> 10 -> 5); other spatial sizes make
    the flattened tensor and ``fc1`` disagree.

    Args:
        train_CNN: Currently unused.
        num_classes: Currently unused; ``fc3`` always produces 10 outputs.
    """

    def __init__(self, train_CNN=False, num_classes=2):
        super().__init__()
        self.conv1 = nn.Conv2d(3, 6, 5)
        self.pool = nn.MaxPool2d(2, 2)
        self.conv2 = nn.Conv2d(6, 16, 5)
        self.fc1 = nn.Linear(16 * 5 * 5, 120)
        self.fc2 = nn.Linear(120, 84)
        # NOTE(review): num_classes is ignored; the output width stays 10 so
        # the model's output shape is unchanged for existing callers.
        self.fc3 = nn.Linear(84, 10)

    def forward(self, x):
        """Return 10 logits per sample; expects a (batch, 3, 32, 32) input."""
        # torch.relu instead of F.relu: the snippet never imports
        # torch.nn.functional as F, so F.relu raised NameError here.
        x = self.pool(torch.relu(self.conv1(x)))
        x = self.pool(torch.relu(self.conv2(x)))
        x = torch.flatten(x, 1)  # flatten all dimensions except batch
        x = torch.relu(self.fc1(x))
        x = torch.relu(self.fc2(x))
        return self.fc3(x)

# Build the model on the selected device and print its torchsummary report.
model = CNN().to(device)
# Passing the torch.device object fails inside torchsummary: it calls
# device.lower() (a str method) and then checks the result against string
# names such as "cuda" (see the traceback below). Passing a plain string
# instead, e.g. device=device.type, presumably avoids this -- verify against
# the installed torchsummary version.
# NOTE(review): CNN.fc1 expects 16 * 5 * 5 features (a 32x32 input). With
# (3, 28, 28) the flattened size is 16 * 4 * 4, so the forward pass would
# fail with a shape mismatch once the device issue is worked around.
summary(model, (3, 28, 28), device=device)

Error message:

---------------------------------------------------------------------------
AttributeError                            Traceback (most recent call last)
~\AppData\Local\Temp\ipykernel_17928\2345870344.py in <module>
     27 
     28 model = CNN().to(device)
---> 29 summary(model, (3, 28, 28), device=device)

~\anaconda3\lib\site-packages\torchsummary\torchsummary.py in summary(model, input_size, batch_size, device)
     42             hooks.append(module.register_forward_hook(hook))
     43 
---> 44     device = device.lower()
     45     assert device in [
     46         "cuda",

AttributeError: 'torch.device' object has no attribute 'lower'