sksq96 / pytorch-summary

Model summary in PyTorch similar to `model.summary()` in Keras
MIT License

Why do I get '2' as batch size? #168

Open Flock1 opened 3 years ago

Flock1 commented 3 years ago

Hey,

This is a really great tool for visualizing a model. I was trying to see how the decoder in my VAE works; its input is the latent space (dim = (2,2)). However, when I look at the output, I see an extra 2. For example, with `summary(decoder, (2,2))` the output is:

DECODER
torch.Size([2, 2, 2])

My decoder is defined like this:

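import torch
import torch.nn as nn
import torch.nn.functional as F

# capacity, latent_dims, input_len and axis_transfer are settings defined elsewhere in the script (not shown)
class Decoder(nn.Module):
    def __init__(self):
        super().__init__()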
        c = capacity
        self.fc = nn.Linear(in_features=latent_dims, out_features=c*2*7*7)
        self.conv2 = nn.ConvTranspose2d(in_channels=c*2, out_channels=c, kernel_size=4, stride=2, padding=1)
        self.conv1 = nn.ConvTranspose2d(in_channels=c, out_channels=1, kernel_size=4, stride=2, padding=1)
        self.adapt = nn.AdaptiveMaxPool1d(input_len)

    def forward(self, x):
        print("DECODER")
        print(x.shape)  # shape of the tensor the decoder actually receives
        x = self.fc(x)
        x = x.reshape(-1,x.shape[0], x.shape[1])
        x = self.adapt(x)
        x = x.view(x.size(0), capacity*2, axis_transfer, axis_transfer) # unflatten batch of feature vectors to a batch of multi-channel feature maps
        x = F.relu(self.conv2(x))
        x = torch.sigmoid(self.conv1(x)) # last layer before output is sigmoid, since we are using BCE as reconstruction loss
        return x

Do let me know.

cainmagi commented 3 years ago

torchsummary runs the network on a test tensor with a batch size of 2 and records information about each layer along the way. See the code here: https://github.com/sksq96/pytorch-summary/blob/011b2bd0ec7153d5842c1b37d1944fc6a7bf5feb/torchsummary/torchsummary.py#L58

Even if you set `batch_size` in the arguments, that value is only used when computing the shapes and sizes shown in the printed summary. The network is still run on the batch-size-2 tensor.

This behavior can cause errors when the network requires a specific batch size. To fix this, I modified the code so that the test tensor uses `batch_size` whenever that value is not None; see https://github.com/sksq96/pytorch-summary/pull/165/files#diff-ebda1cc7f304708e45ef4e19fb0484036eff8eb3c4b47a2886ca1cf0f731c0bbR118

That said, the author does not seem to have maintained this package for a long time. I recommend trying an alternative such as torchinfo.

Flock1 commented 3 years ago

Thanks a lot. I wanted to ask: why does it use 1 as the batch size when I pass an image-like shape such as (3,28,28)? In that case, I don't see 2 as the batch size.

I will definitely check out torchinfo.

cainmagi commented 3 years ago

I do not understand your question. In your previous posts, you did not mention any batch with a batch size of 1.

By the way, I also do not understand what you mean by "I don't see '2' as batch size", because you mentioned that your output is

DECODER
torch.Size([2, 2, 2])

Why do you say you do not see 2 as the batch size? The first element of the printed shape is clearly 2.

Here is a tip: if you are using

torchsummary.summary(..., input_size=...)

you should not pass an `input_size` like `[3, 28, 28]`; that causes errors. Instead, use `((3, 28, 28), )` or `(3, 28, 28)`. The official implementation is quite fragile in some cases.

letsgo247 commented 3 years ago

@cainmagi Thanks! torchinfo works!