Open — Skaifai opened this issue 8 months ago.
`torchsummary.summary` does not work when its `device` argument is a `torch.device` object instead of a string. Code to reproduce the error:
import torch
import torch.nn as nn
import torch.nn.functional as F  # required: forward() calls F.relu
from torchvision import models
from torchsummary import summary

device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")
print("Using ", device)


class CNN(nn.Module):
    """Small LeNet-style CNN used to reproduce the torchsummary device error.

    The layer sizes assume 3x32x32 inputs: each (5x5 conv, no padding ->
    2x2 max-pool) stage maps 32 -> 28 -> 14 and then 14 -> 10 -> 5, which is
    why ``fc1`` takes ``16 * 5 * 5`` features.

    Args:
        train_CNN: Accepted but currently unused.
        num_classes: Accepted but currently unused; ``fc3`` always
            produces 10 outputs.
    """

    def __init__(self, train_CNN=False, num_classes=2):
        super().__init__()
        self.conv1 = nn.Conv2d(3, 6, 5)
        self.pool = nn.MaxPool2d(2, 2)
        self.conv2 = nn.Conv2d(6, 16, 5)
        self.fc1 = nn.Linear(16 * 5 * 5, 120)
        self.fc2 = nn.Linear(120, 84)
        self.fc3 = nn.Linear(84, 10)

    def forward(self, x):
        """Run two conv/pool stages followed by three fully connected layers."""
        x = self.pool(F.relu(self.conv1(x)))
        x = self.pool(F.relu(self.conv2(x)))
        x = torch.flatten(x, 1)  # flatten all dimensions except batch
        x = F.relu(self.fc1(x))
        x = F.relu(self.fc2(x))
        x = self.fc3(x)
        return x


model = CNN().to(device)

# Reproduces the reported bug: torchsummary calls device.lower(), which a
# torch.device object does not have, so this raises AttributeError.
# Workaround until torchsummary accepts torch.device: pass device=device.type
# (the plain string "cuda" or "cpu").
# Input is 3x32x32 because fc1 expects 16*5*5 features; a 28x28 input would
# produce 16*4*4 features and fail with a shape mismatch once the device
# problem is worked around.
summary(model, (3, 32, 32), device=device)
Error message:
---------------------------------------------------------------------------
AttributeError                            Traceback (most recent call last)
~\AppData\Local\Temp\ipykernel_17928\2345870344.py in <module>
     27
     28 model = CNN().to(device)
---> 29 summary(model, (3, 28, 28), device=device)

~\anaconda3\lib\site-packages\torchsummary\torchsummary.py in summary(model, input_size, batch_size, device)
     42         hooks.append(module.register_forward_hook(hook))
     43
---> 44     device = device.lower()
     45     assert device in [
     46         "cuda",

AttributeError: 'torch.device' object has no attribute 'lower'

The failure comes from `device = device.lower()` at line 44 of `torchsummary.py`. `summary` assumes `device` is a string such as `"cuda"`, so a `torch.device` object raises `AttributeError`. Passing `device=device.type` instead should work around it.
`torchsummary.summary` does not work when its `device` argument is a `torch.device` object instead of a string. Code to reproduce the error:
Error message: