sksq96 / pytorch-summary

Model summary in PyTorch similar to `model.summary()` in Keras
MIT License
3.98k stars 412 forks

feat: Default device is set to model device #117

Open frgfm opened 4 years ago

frgfm commented 4 years ago

Avoids specifying the device explicitly, since the input tensor needs to be on the same device as the model. This is useful in multi-GPU environments, or to freely use the function on CPU.

Previously

from torchvision.models import resnet18
from torchsummary import summary
model = resnet18().eval()
summary(model, (3, 224, 224))

would yield

RuntimeError: Input type (torch.cuda.FloatTensor) and weight type (torch.FloatTensor) should be the same

The device is now dynamically set to the model device.
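A minimal sketch of the idea behind this PR (hypothetical helper name, assuming the model has at least one parameter and is not sharded across devices): the device can be read off the model's own parameters, so the dummy input can be created wherever the model already lives.

```python
import torch
from torch import nn

def model_device(model: nn.Module) -> torch.device:
    # Read the device from the first parameter; assumes the model
    # has at least one parameter and all of them share one device.
    return next(model.parameters()).device

model = nn.Linear(4, 2)  # lives on CPU by default
device = model_device(model)
# Create the dummy input on the same device as the model,
# avoiding the "Input type ... and weight type ..." RuntimeError.
x = torch.rand(1, 4, device=device)
out = model(x)
```

The same call works unchanged if the model is later moved with `model.cuda()`, since the device is re-read from the parameters each time.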

Any feedback is welcome, cheers!

leriomaggio commented 4 years ago

The current implementation in torchsummary is not backward compatible with the 1.5.1 version installed via pip (in which the summary_string function is not available).

The current implementation also assumes CUDA by default, which is not the case for me (for example, using torch on a Mac).

This PR is very flexible adapting the device parameter default value to the one used by the model: very neat and elegant solution! Well done!

So totally 👍 for me!

cc/ @sksq96 @frgfm

chrismaliszewski commented 3 years ago

I'd like to add my two cents to the discussion. I first ran into the problem of the module defaulting to a GPU device while I only have a CPU. I modified the module's code as follows:

def summary(model, input_size, batch_size=-1, device=None, dtypes=None):
    if device is None:
        device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    result, params_info = summary_string(
        model, input_size, batch_size, device, dtypes)
    print(result)

    return params_info

def summary_string(model, input_size, batch_size=-1, device=None, dtypes=None):
    if device is None:
        device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
...

but it created a very similar problem to the one @frgfm had ("Input type (torch.FloatTensor) and weight type (torch.DoubleTensor) should be the same"). So here is another fix to the problem:

def summary_string(model, input_size, batch_size=-1, device=None, dtypes=None):
    if device is None:
        device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    if dtypes is None:
        dtypes = [torch.Tensor().to(device).type()] * len(input_size)

So instead of hard-coding torch.FloatTensor as the default, create a tensor, move it to the target device, and use its reported type. That prevents the mismatch from happening.
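As a quick illustration of that trick (a sketch, run on CPU here; the input_size value is made up): an empty tensor moved to the target device reports the device-matching tensor type string, which can then be used as the default dtype for every input.

```python
import torch

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
# An empty tensor moved to the device reports the matching type string,
# e.g. 'torch.FloatTensor' on CPU or 'torch.cuda.FloatTensor' on CUDA.
default_dtype = torch.Tensor().to(device).type()
input_size = [(3, 224, 224), (1, 28, 28)]  # hypothetical two-input model
dtypes = [default_dtype] * len(input_size)
print(dtypes)
```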

frgfm commented 3 years ago

Hi @chrismaliszewski,

Thanks for the suggestion! I had that in mind too; for personal use, I made a whole new Python package, which I maintain very actively (and added a few other features, including experimental receptive field computation). Here it is: https://github.com/frgfm/torch-scan

Feel free to drop suggestions or issues, hope that helps!