TylerYep / torchinfo

View model summaries in PyTorch!
MIT License
2.42k stars 115 forks

torchinfo.summary should not change the device of model #287

Open gitxw opened 7 months ago

gitxw commented 7 months ago

Describe the bug
torchinfo.summary should not change the device of the model.

To Reproduce
model = model.cpu()
torchinfo.summary(model)
Why is the model moved to CUDA after calling torchinfo.summary() when a graphics card is available? This was unexpected. I just want to dump the structure of my model.

Expected behavior
torchinfo.summary should preserve the state of the model, not change it. Passing the "device" parameter to torchinfo.summary should not be necessary.


1-ashraful-islam commented 1 month ago

Had the same issue using torchinfo.summary. The current workaround seems to be passing device as an additional argument.

Code example:


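import torch
from torchinfo import summary

# ConvModel and batch_size are assumed to be defined elsewhere
use_cuda = torch.cuda.is_available()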
use_mps = torch.backends.mps.is_available()

if use_cuda:
    device = torch.device("cuda")
elif use_mps:
    device = torch.device("mps")
else:
    device = torch.device("cpu")

print(f"Using device (CUDA/MPS/CPU): {device}")

# initialize the model
model = ConvModel().to(device)

# print model info
print(summary(model, input_size=(batch_size, 1, 28, 28), device=device))

Device: Mac with Apple silicon
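
For anyone who would rather not pass the device explicitly, another possible workaround is to record the model's device before the call and move the model back afterwards. The sketch below is only an illustration: `preserve_device` is a hypothetical helper, not part of torchinfo, and it assumes all parameters of the model live on a single device.

```python
from contextlib import contextmanager

import torch
from torch import nn


@contextmanager
def preserve_device(model: nn.Module):
    """Hypothetical helper (not a torchinfo API): restore the model's device on exit."""
    # Assumption: every parameter is on the same device, so the first one is representative.
    first_param = next(model.parameters(), None)
    original = first_param.device if first_param is not None else None
    try:
        yield model
    finally:
        # Move the model back even if the wrapped call raised an exception.
        if original is not None:
            model.to(original)


# Small stand-in model so the sketch runs without a GPU.
model = nn.Linear(4, 2).cpu()

with preserve_device(model):
    # The call that might move the model would go here,
    # e.g. torchinfo.summary(model, input_size=(1, 4)).
    pass

restored_device = next(model.parameters()).device
```

In practice the body of the `with` block would be the `torchinfo.summary(...)` call. Whatever device the model ends up on inside the block, it is moved back to where it started once the block exits.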