sksq96 / pytorch-summary

Model summary in PyTorch similar to `model.summary()` in Keras
MIT License

Hello! I modified the method of counting parameters #126

a45s67 closed this 1 year ago

a45s67 commented 4 years ago

Hi everyone. I found that the original code only counts parameters named "weight" and "bias". This breaks for custom modules that name their parameters differently, such as:

```python
import torch
import torch.nn as nn

class NoiseLinear(nn.Module):
    def __init__(self, inf, outf):
        super().__init__()
        # torch.empty allocates uninitialized storage, like torch.Tensor(...)
        self.w_mu = nn.Parameter(torch.empty(outf, inf))
        self.w_sig = nn.Parameter(torch.empty(outf, inf))
        self.b_mu = nn.Parameter(torch.empty(outf))
        self.b_sig = nn.Parameter(torch.empty(outf))

        # no parameters named "weight" and "bias" here
```

With the original counting method, this module is reported as having 0 parameters. So I modified the code to count parameters through the `parameters()` iterator instead. It has worked well for me after using it for a while.
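The diff for this PR is not reproduced here, so the following is only a sketch of the difference, assuming the original logic inspects just the `weight` and `bias` attributes and the new logic multiplies out the size of every tensor yielded by `parameters()`. The helper names `count_weight_bias_only` and `count_via_parameters` are hypothetical.

```python
import torch
import torch.nn as nn


class NoiseLinear(nn.Module):
    def __init__(self, inf, outf):
        super().__init__()
        self.w_mu = nn.Parameter(torch.empty(outf, inf))
        self.w_sig = nn.Parameter(torch.empty(outf, inf))
        self.b_mu = nn.Parameter(torch.empty(outf))
        self.b_sig = nn.Parameter(torch.empty(outf))


def count_weight_bias_only(module):
    # Hypothetical approximation of the original behaviour: only the
    # attributes literally named "weight" and "bias" are inspected.
    total = 0
    for name in ("weight", "bias"):
        p = getattr(module, name, None)
        if isinstance(p, torch.Tensor):
            total += int(torch.prod(torch.LongTensor(list(p.size()))))
    return total


def count_via_parameters(module):
    # Counts every parameter registered directly on this module, whatever
    # its name. recurse=False keeps children out of the count.
    total = 0
    for p in module.parameters(recurse=False):
        total += int(torch.prod(torch.LongTensor(list(p.size()))))
    return total
```

For `NoiseLinear(3, 2)` the first helper returns 0, while the second returns 16 (6 + 6 + 2 + 2). For a plain `nn.Linear(3, 2)`, both return 8. `recurse=False` is a design choice that assumes the count is taken once per layer, as a per-layer summary does. Without it, a container module would also count its children's parameters, and they would be counted twice.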

luisherrmann commented 4 years ago

Hi, good call! I have been thinking of doing the same thing. However, instead of using `prod()` on the tensor sizes, I would directly use the `numel()` method of `nn.parameter.Parameter`.
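Here is a minimal sketch of the `numel()` variant. `Tensor.numel()` returns the number of elements, which equals the product of the tensor's sizes, so no `prod()` over `size()` is needed. The helper `count_params` is hypothetical. Splitting trainable from non-trainable parameters by each parameter's `requires_grad` flag is an assumption added for illustration, not something stated in this thread.

```python
import torch.nn as nn


def count_params(module, recurse=False):
    """Return (total, trainable) parameter counts using Tensor.numel()."""
    total = 0
    trainable = 0
    for p in module.parameters(recurse=recurse):
        n = p.numel()  # number of elements, i.e. the product of p.size()
        total += n
        if p.requires_grad:
            trainable += n
    return total, trainable
```

For example, `count_params(nn.Linear(3, 2))` gives `(8, 8)`. After `layer.bias.requires_grad_(False)`, the same layer gives `(8, 6)`.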

I have made a separate pull request with a couple of other suggestions and additional tests.