Open gitxw opened 7 months ago
Had the same issue using `torchinfo.summary`.
The current workaround seems to be passing the device as an additional argument.
Code example:
```python
import torch
from torchinfo import summary

# Pick the best available device: CUDA, then MPS, then CPU
use_cuda = torch.cuda.is_available()
use_mps = torch.backends.mps.is_available()
if use_cuda:
    device = torch.device("cuda")
elif use_mps:
    device = torch.device("mps")
else:
    device = torch.device("cpu")
print(f"Using device (CUDA/MPS/CPU): {device}")

# initialize the model
model = ConvModel().to(device)

# print model info, passing the device explicitly so that
# summary() does not move the model somewhere else
print(summary(model, input_size=(batch_size, 1, 28, 28), device=device))
```
Device: Mac with Apple silicon
Describe the bug
`torchinfo.summary` should not change the device of the model.

To Reproduce
`model = model.cpu()` followed by `torchinfo.summary(model)`. Why was the model moved to CUDA after calling `torchinfo.summary()` when I have a graphics card? This was unexpected; I just want to dump the structure of my model.

Expected behavior
`torchinfo.summary` should preserve the model's device, not change it. The `device` parameter of `torchinfo.summary` should not be necessary.
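Until the library guarantees this, the desired behavior can be approximated with a small wrapper that records the model's device before summarizing and moves it back afterwards. This is a minimal sketch, not torchinfo's API: the helper name `summarize_preserving_device` and the `summary_fn` callback are my own, and it assumes the model has at least one parameter.

```python
import torch.nn as nn


def summarize_preserving_device(model: nn.Module, summary_fn, *args, **kwargs):
    """Call summary_fn(model, ...) and restore the model's original device.

    summary_fn can be any summary routine (e.g. torchinfo.summary);
    this helper only guarantees the model ends up where it started,
    even if the summary call raises.
    """
    # Remember where the model's parameters currently live
    original_device = next(model.parameters()).device
    try:
        return summary_fn(model, *args, **kwargs)
    finally:
        # Move the model back regardless of what summary_fn did
        model.to(original_device)
```

Used as `summarize_preserving_device(model, summary, input_size=(1, 1, 28, 28))`, the model stays on whatever device it was on before the call.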
Desktop (please complete the following information):