Closed: t4rf9 closed this issue 1 year ago
Have you tried passing the device parameter to summary?
I have successfully replicated the error.
Running
from torch import nn
import torch
from torchinfo import summary

# Put both the model and the input on cuda:1
linear = nn.Linear(1000, 1000).to(1)
a = torch.randn(32, 1000).to(1)
_ = summary(linear, input_data=a)
next(linear.parameters()).device
yields:
device(type='cuda', index=0)
And it turns out that using the device
parameter works just fine:
from torch import nn
import torch
from torchinfo import summary

linear = nn.Linear(1000, 1000).to(1)
a = torch.randn(32, 1000).to(1)
# Explicitly pin summary to cuda:1 so the model is not moved
_ = summary(linear, input_data=a, device=torch.device("cuda:1"))
next(linear.parameters()).device
yields:
device(type='cuda', index=1)
Still, this issue is addressed in pull request #211.
Describe the bug
For CPU usage, torchinfo.summary works well. For models on GPU, if the model is not on GPU 0 (e.g. it is on GPU 5), torchinfo.summary moves it to GPU 0.
To Reproduce
Steps to reproduce the behavior:

linear = nn.Linear(1000, 1000).to(1)
a = torch.randn(32, 1000).to(1)
summary(linear, input_data=a)
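Until the fix lands, one workaround that needs only plain torch is to record the module's device before calling summary and move it back afterwards. A sketch (restore_device is a hypothetical helper, not part of torchinfo; the summary call is elided so the snippet also runs without CUDA):

```python
import torch
from torch import nn

def restore_device(module: nn.Module, device: torch.device) -> nn.Module:
    # Move the module back if a library call relocated its parameters.
    if next(module.parameters()).device != device:
        module.to(device)
    return module

linear = nn.Linear(1000, 1000)
original = next(linear.parameters()).device
# ... summary(linear, input_data=a) would go here and may move the model ...
restore_device(linear, original)
assert next(linear.parameters()).device == original
```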