amarczew / pytorch_model_summary

MIT License
52 stars, 15 forks

Double Counting Parameters #3

Closed: TortoiseHam closed this issue 4 years ago

TortoiseHam commented 4 years ago

When a model takes multiple inputs (e.g. a Siamese network) that all pass through the same set of layers before the results are compared later in the network, the model summary double-counts all of the shared parameters. For example, in the network below, if layerA has 10 parameters and layerB has 5, the summary reports 25 parameters instead of 15.

x1 -> layerA --v
               +-> layerB -> y
x2 -> layerA --^
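To make the arithmetic in the example concrete, here is a minimal pure-Python sketch (no PyTorch) that logs each layer call in the order a per-call summary would see it, then compares summing per call with summing per distinct layer object. The `Layer` class and the 10/5 parameter counts are hypothetical stand-ins taken from the example in this issue, not code from pytorch_model_summary.

```python
from dataclasses import dataclass


@dataclass
class Layer:
    """Hypothetical stand-in for an nn.Module; only its parameter count matters here."""
    name: str
    num_params: int


layer_a = Layer("layerA", 10)
layer_b = Layer("layerB", 5)

# Order in which layers run during one forward pass:
# layerA once per input (x1, then x2), then layerB once on the combined result.
call_log = [layer_a, layer_a, layer_b]

# Buggy behaviour: one summary row (and one parameter count) per call.
per_call_total = sum(layer.num_params for layer in call_log)

# Fixed behaviour: count each distinct layer object only once, keyed by id().
seen_ids = set()
unique_total = 0
for layer in call_log:
    if id(layer) not in seen_ids:
        seen_ids.add(id(layer))
        unique_total += layer.num_params

print(per_call_total, unique_total)  # 25 15
```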

TortoiseHam commented 4 years ago

Besides inflating the parameter count, it also lists the shared layers twice in the summary. Here's the code for a model that shows the problem:

https://github.com/fastestimator/fastestimator/blob/master/apphub/one_shot_learning/siamese_network/siamese_torch.py

amarczew commented 4 years ago

Hi @TortoiseHam!

Thank you for reporting the issue! I understand the scenario, but could you send a gist or code snippet with an example that I can run quickly to reproduce the problem?

TortoiseHam commented 4 years ago
import torch
import torch.nn as nn
import torch.nn.functional as fn
import pytorch_model_summary as pms

class SiameseNetwork(nn.Module):
    def __init__(self):
        super().__init__()
        self.conv1 = nn.Conv2d(3, 64, 10)
        self.fc1 = nn.Linear(6400, 1)

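    def branch_forward(self, x):
        # Shared branch: both inputs go through the same conv1 module (same weights).
        x = self.conv1(x)           # [N, 3, 30, 30] -> [N, 64, 21, 21]
        x = fn.max_pool2d(x, 2)     # -> [N, 64, 10, 10]
        x = fn.relu(x)
        x = x.view(x.shape[0], -1)  # flatten -> [N, 6400]
        return x

    def forward(self, x):
        # x is a list of two image batches; conv1 is called once per batch.
        x1 = self.branch_forward(x[0])
        x2 = self.branch_forward(x[1])
        x = torch.abs(x1 - x2)      # element-wise distance between the two embeddings
        x = self.fc1(x)
        x = torch.sigmoid(x)        # similarity score in (0, 1)
        return x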

if __name__ == "__main__":
    net = SiameseNetwork()
    img_a = torch.ones(size=(5,3,30,30))
    img_b = torch.zeros(size=(5,3,30,30))
    pms.summary(net, [img_a, img_b], print_summary=True)
TortoiseHam commented 4 years ago

This prints the following:

-----------------------------------------------------------------------
      Layer (type)        Output Shape         Param #     Tr. Param #
=======================================================================
          Conv2d-1     [5, 64, 21, 21]          19,264          19,264
          Conv2d-2     [5, 64, 21, 21]          19,264          19,264
          Linear-3              [5, 1]           6,401           6,401
=======================================================================
Total params: 44,929
Trainable params: 44,929
Non-trainable params: 0
-----------------------------------------------------------------------
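The numbers in this table can be checked by hand with the standard weight-plus-bias formulas: `Conv2d(3, 64, 10)` has 64 kernels of size 3 x 10 x 10 plus 64 biases, and `Linear(6400, 1)` has 6400 weights plus one bias (6400 = 64 x 10 x 10, the flattened size after pooling the 21 x 21 feature map). The plain-Python sketch below assumes both layers use the default `bias=True` and shows that the reported total counts the shared convolution twice.

```python
def conv2d_params(in_ch: int, out_ch: int, kernel: int, bias: bool = True) -> int:
    """Square-kernel 2-D convolution: out_ch * in_ch * k * k weights (+ out_ch biases)."""
    return out_ch * in_ch * kernel * kernel + (out_ch if bias else 0)


def linear_params(in_features: int, out_features: int, bias: bool = True) -> int:
    """Fully connected layer: in_features * out_features weights (+ out_features biases)."""
    return in_features * out_features + (out_features if bias else 0)


conv1 = conv2d_params(3, 64, 10)  # 19,264
fc1 = linear_params(6400, 1)      # 6,401

reported_total = 2 * conv1 + fc1  # conv1 listed once per branch call
actual_total = conv1 + fc1        # conv1's weights exist only once

print(f"{conv1:,} {fc1:,} {reported_total:,} {actual_total:,}")
# 19,264 6,401 44,929 25,665
```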
TortoiseHam commented 4 years ago

Thanks for the reply, by the way. I suspect this could be fixed by recording the ids of the layers that have already been seen while crawling the model, but I haven't had time to study the repo's internals yet.
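The id-tracking idea can be sketched with PyTorch forward hooks. This is an illustration under stated assumptions, not pytorch_model_summary's actual code: the helper `count_called_params` and the model class `Siamese` are hypothetical, hooks are registered only on leaf modules via `register_forward_hook`, and a module is recorded only the first time it runs (keyed by `id(module)`).

```python
import torch
import torch.nn as nn
import torch.nn.functional as F


def count_called_params(model: nn.Module, *inputs):
    """Hypothetical helper: run one forward pass and sum the parameters of every
    leaf module that was called, counting each module only on its first call."""
    seen_ids = set()
    rows = []  # (module class name, parameter count), one row per distinct module

    def hook(module, _inputs, _output):
        if id(module) in seen_ids:
            return  # shared layer called again (e.g. the second Siamese branch): skip it
        seen_ids.add(id(module))
        rows.append((type(module).__name__, sum(p.numel() for p in module.parameters())))

    # Hook only leaf modules (no children) so container modules are not counted too.
    handles = [m.register_forward_hook(hook)
               for m in model.modules() if not list(m.children())]
    try:
        with torch.no_grad():
            model(*inputs)
    finally:
        for h in handles:
            h.remove()  # always detach the hooks, even if the forward pass fails
    return rows, sum(n for _, n in rows)


class Siamese(nn.Module):
    """Hypothetical model with the same layer sizes as the example in this thread."""

    def __init__(self):
        super().__init__()
        self.conv1 = nn.Conv2d(3, 64, 10)
        self.fc1 = nn.Linear(6400, 1)

    def branch(self, x):
        return F.relu(F.max_pool2d(self.conv1(x), 2)).flatten(1)

    def forward(self, x):
        return torch.sigmoid(self.fc1(torch.abs(self.branch(x[0]) - self.branch(x[1]))))


rows, total = count_called_params(
    Siamese(), [torch.ones(5, 3, 30, 30), torch.zeros(5, 3, 30, 30)])
print(rows, total)  # [('Conv2d', 19264), ('Linear', 6401)] 25665
```

Keying on `id(module)` handles a layer that is reused across calls. Weights tied between two *different* module objects would still be counted twice by this sketch; those would need deduplication per parameter instead, which is what `model.parameters()` already does.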

amarczew commented 4 years ago

@TortoiseHam you are right. Internally, I have a dictionary of layers to be printed; I just need to add a flag_print. I will fix this in a few days. Thank you so much!

jonashaag commented 4 years ago

I also have this issue.

amarczew commented 4 years ago

Hi @jonashaag, unfortunately I haven't been able to implement a fix for this problem over the last few weeks. I will try to fix it next weekend.

Thank you for your report.

amarczew commented 4 years ago

@jonashaag and @TortoiseHam, the issue is resolved.

@TortoiseHam, for your example the output is now:

-----------------------------------------------------------------------
      Layer (type)        Output Shape         Param #     Tr. Param #
=======================================================================
          Conv2d-1     [5, 64, 21, 21]          19,264          19,264
          Linear-2              [5, 1]           6,401           6,401
=======================================================================
Total params: 25,665
Trainable params: 25,665
Non-trainable params: 0
-----------------------------------------------------------------------
TortoiseHam commented 4 years ago

Perfect! Thanks for resolving this.