sksq96 / pytorch-summary

Model summary in PyTorch similar to `model.summary()` in Keras
MIT License

torch.cat() isn't displayed in the summary, but the inner contents of a block are. #152

Open: gchhablani opened this issue 3 years ago

gchhablani commented 3 years ago

My forward method is:


    def forward(self,x,skip):
        x = self.upSamp(x)
        print(x.shape)
        x = self.convRelu1(x)
        print(x.shape)
        x = torch.cat((x,skip),1)
        print(x.shape)
        return x
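For anyone trying to reproduce this, here is a minimal sketch of a block that is consistent with the printed shapes. The hyperparameters are assumptions inferred from the output, not taken from the issue: `nn.Upsample(scale_factor=2)` on a 27×27 input gives 54×54, an unpadded `nn.Conv2d(64, 128, kernel_size=2)` gives 53×53 and exactly 32,896 parameters, and `skip` must carry 32 channels for the concatenation to reach 160. `UpBlock` is a hypothetical name, and the debugging `print` calls are dropped.

```python
import torch
import torch.nn as nn


class ConvReLU(nn.Module):
    # Assumed structure: Conv2d followed by ReLU, matching rows 2-4 of the summary.
    def __init__(self, in_ch, out_ch):
        super().__init__()
        # kernel_size=2 is an inference from the 32,896 parameter count.
        self.conv = nn.Conv2d(in_ch, out_ch, kernel_size=2)
        self.relu = nn.ReLU()

    def forward(self, x):
        return self.relu(self.conv(x))


class UpBlock(nn.Module):  # hypothetical name for the block in the issue
    def __init__(self):
        super().__init__()
        self.upSamp = nn.Upsample(scale_factor=2)  # 27x27 -> 54x54 (assumed input size)
        self.convRelu1 = ConvReLU(64, 128)         # 54x54 -> 53x53

    def forward(self, x, skip):
        x = self.upSamp(x)
        x = self.convRelu1(x)
        # Plain function call, not an nn.Module.
        return torch.cat((x, skip), 1)


block = UpBlock()
out = block(torch.randn(2, 64, 27, 27), torch.randn(2, 32, 53, 53))
print(out.shape)  # torch.Size([2, 160, 53, 53])
```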

I get the following output from `summary()` (the first three lines come from the `print` calls in `forward`):

torch.Size([2, 64, 54, 54])
torch.Size([2, 128, 53, 53])
torch.Size([2, 160, 53, 53])
----------------------------------------------------------------
        Layer (type)               Output Shape         Param #
================================================================
          Upsample-1           [-1, 64, 54, 54]               0
            Conv2d-2          [-1, 128, 53, 53]          32,896
              ReLU-3          [-1, 128, 53, 53]               0
          ConvReLU-4          [-1, 128, 53, 53]               0
================================================================
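The numbers in the table can be checked by hand. A biased `Conv2d` (with `groups=1`) holds `in_channels * out_channels * kH * kW` weights plus `out_channels` biases, and its spatial output follows the formula in the PyTorch `Conv2d` docs. The sketch below assumes a 2×2 kernel with stride 1 and no padding, which is an inference from the table rather than something stated in the issue; under that assumption it reproduces both the 32,896 in row 2 and the 54 → 53 size change.

```python
def conv2d_params(in_ch: int, out_ch: int, kh: int, kw: int, bias: bool = True) -> int:
    """Parameter count of a 2D convolution with groups=1."""
    weights = in_ch * out_ch * kh * kw
    return weights + (out_ch if bias else 0)


def conv2d_out(size: int, kernel: int, stride: int = 1, padding: int = 0, dilation: int = 1) -> int:
    """Spatial output size of a convolution along one dimension."""
    return (size + 2 * padding - dilation * (kernel - 1) - 1) // stride + 1


print(conv2d_params(64, 128, 2, 2))  # 32896
print(conv2d_out(54, 2))             # 53
```

Upsample and ReLU have no learnable parameters, which matches the zeros in their rows.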

Clearly, `summary()` ignores the `torch.cat()` call inside `forward()`: the last printed shape, `[2, 160, 53, 53]`, has no matching row in the table.
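This is consistent with how pytorch-summary builds its table: it registers forward hooks on the model's submodules, and a forward hook can only be attached to an `nn.Module`. `torch.cat` is a plain function call, so nothing records it. One possible workaround, sketched below, is to wrap the concatenation in a small module so a hook has something to attach to. `Concat` and `UpBlock` are hypothetical names, the ConvReLU wrapper is flattened into `conv` and `relu` for brevity, and the layer settings are inferred from the printed shapes.

```python
import torch
import torch.nn as nn


class Concat(nn.Module):
    """Hypothetical helper: torch.cat wrapped in a module so forward hooks can see it."""

    def __init__(self, dim: int = 1):
        super().__init__()
        self.dim = dim

    def forward(self, *tensors):
        # Tensors arrive as separate positional args, so the hook's input[0] is a tensor.
        return torch.cat(tensors, dim=self.dim)


class UpBlock(nn.Module):  # hypothetical; settings inferred from the printed shapes
    def __init__(self):
        super().__init__()
        self.upSamp = nn.Upsample(scale_factor=2)
        self.conv = nn.Conv2d(64, 128, kernel_size=2)
        self.relu = nn.ReLU()
        self.cat = Concat(dim=1)

    def forward(self, x, skip):
        x = self.relu(self.conv(self.upSamp(x)))
        return self.cat(x, skip)


# Record which submodules fire a forward hook, roughly what a hook-based summary sees.
seen = []
block = UpBlock()
for m in block.children():
    m.register_forward_hook(lambda mod, inp, out: seen.append(type(mod).__name__))

out = block(torch.randn(2, 64, 27, 27), torch.randn(2, 32, 53, 53))
print(seen)  # ['Upsample', 'Conv2d', 'ReLU', 'Concat']
```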

Also, the block (`ConvReLU`) is listed after its inner components (`Conv2d` and `ReLU`). Only one of the two should appear in the table: either the block or its components, not both.
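The ordering follows from when forward hooks run: a module's forward hook fires after its `forward` returns, and a container's `forward` can only return once its children have finished. So `Conv2d` and `ReLU` are recorded before `ConvReLU`. The sketch below hooks every module of an assumed `ConvReLU` and prints the firing order. The final filter, which keeps only leaf modules, is one hypothetical way to drop the duplicate wrapper row; it is not an existing pytorch-summary option.

```python
import torch
import torch.nn as nn


class ConvReLU(nn.Module):
    # Assumed layout of the issue's block: Conv2d followed by ReLU.
    def __init__(self):
        super().__init__()
        self.conv = nn.Conv2d(64, 128, kernel_size=2)
        self.relu = nn.ReLU()

    def forward(self, x):
        return self.relu(self.conv(x))


block = ConvReLU()
order = []
# Hook every module, including the block itself; each hook runs after that module's forward returns.
for m in block.modules():
    m.register_forward_hook(lambda mod, inp, out: order.append(type(mod).__name__))

block(torch.randn(2, 64, 54, 54))
print(order)  # ['Conv2d', 'ReLU', 'ConvReLU']

# One possible de-duplication: report only leaf modules (no children of their own).
leaves = [m for m in block.modules() if len(list(m.children())) == 0]
print([type(m).__name__ for m in leaves])  # ['Conv2d', 'ReLU']
```

Keeping only leaves loses the block-level grouping, so the opposite choice (hooking only top-level children) is equally valid. Which one is preferable depends on how much detail the summary should show.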