sksq96 / pytorch-summary

Model summary in PyTorch similar to `model.summary()` in Keras
MIT License

Hello! I modified the way parameters are counted #125

Closed: a45s67 closed this 4 years ago

a45s67 commented 4 years ago

Hi everyone. I found that the original code only counts parameters named "weight" and "bias". This breaks down when we write custom modules, such as:

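import torch
import torch.nn as nn
import torch.nn.functional as F

class NoiseLinear(nn.Module):
    def __init__(self, inf, outf):
        super(NoiseLinear, self).__init__()
        # no parameters named "weight" or "bias" here; initial values are illustrative
        self.w_mu = nn.Parameter(torch.randn(outf, inf) * 0.1)
        self.w_sig = nn.Parameter(torch.full((outf, inf), 0.01))
        self.b_mu = nn.Parameter(torch.zeros(outf))
        self.b_sig = nn.Parameter(torch.full((outf,), 0.01))

    def forward(self, x):
        # assumed noisy-linear forward: sample weight and bias as mu + sig * noise
        w = self.w_mu + self.w_sig * torch.randn_like(self.w_sig)
        b = self.b_mu + self.b_sig * torch.randn_like(self.b_sig)
        return F.linear(x, w, b)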
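To illustrate why name-based counting misses a module like this, here is a simplified sketch. The helper `count_weight_bias` and the class `TwoParams` are hypothetical; they only approximate the idea of looking up attributes literally named `weight` and `bias`, not the library's exact code:

```python
import torch
import torch.nn as nn


def count_weight_bias(module: nn.Module) -> int:
    """Count only attributes literally named `weight` and `bias` (hypothetical sketch)."""
    total = 0
    for name in ("weight", "bias"):
        # nn.Module raises AttributeError for missing attributes, so getattr falls back to None
        tensor = getattr(module, name, None)
        if isinstance(tensor, torch.Tensor):
            total += tensor.numel()
    return total


class TwoParams(nn.Module):
    # Hypothetical module whose parameters use custom names.
    def __init__(self, inf: int, outf: int):
        super().__init__()
        self.w_mu = nn.Parameter(torch.zeros(outf, inf))
        self.b_mu = nn.Parameter(torch.zeros(outf))


custom = TwoParams(4, 3)
standard = nn.Linear(4, 3)
print(count_weight_bias(custom))                     # 0
print(count_weight_bias(standard))                   # 15 (12 weights + 3 biases)
print(sum(p.numel() for p in custom.parameters()))   # 15
```

Both modules hold 15 values, but only the one that uses the standard attribute names is counted by the name-based lookup.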

With the original counting method, such a module is reported as having 0 parameters. So I modified the code to count parameters through the module's parameters() iterator instead. It has worked well for me after using it for a while.
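Below is a minimal sketch of counting through the `parameters()` iterator inside forward hooks, in the spirit of this change. It is not the PR's actual diff: the helpers `param_counts` and `summarize`, the use of `recurse=False`, and hooking only leaf modules are assumptions chosen so that each parameter is counted once.

```python
from typing import List, Tuple

import torch
import torch.nn as nn
import torch.nn.functional as F


def param_counts(module: nn.Module) -> Tuple[int, int]:
    """Return (total, trainable) for parameters registered directly on `module`.

    Iterating parameters() picks up any name (w_mu, b_sig, ...), not just weight/bias.
    recurse=False avoids re-counting a child's parameters at its parent.
    """
    total = trainable = 0
    for p in module.parameters(recurse=False):
        total += p.numel()
        if p.requires_grad:
            trainable += p.numel()
    return total, trainable


def summarize(model: nn.Module, x: torch.Tensor) -> List[Tuple[str, int]]:
    """Run one forward pass and record (module name, parameter count) per leaf module."""
    rows: List[Tuple[str, int]] = []
    handles = []

    def make_hook(name: str):
        def hook(module, inputs, output):
            total, _ = param_counts(module)
            rows.append((name, total))
        return hook

    for name, m in model.named_modules():
        if next(m.children(), None) is None:  # leaf modules only (assumption)
            handles.append(m.register_forward_hook(make_hook(name)))
    try:
        with torch.no_grad():
            model(x)
    finally:
        # always detach the hooks, even if the forward pass fails
        for h in handles:
            h.remove()
    return rows


class CustomLinear(nn.Module):
    # Hypothetical module whose parameters are not named weight/bias.
    def __init__(self, inf: int, outf: int):
        super().__init__()
        self.w_mu = nn.Parameter(torch.zeros(outf, inf))
        self.b_mu = nn.Parameter(torch.zeros(outf))

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        return F.linear(x, self.w_mu, self.b_mu)


model = nn.Sequential(nn.Linear(4, 3), nn.ReLU(), CustomLinear(3, 2))
rows = summarize(model, torch.randn(1, 4))
print(rows)  # [('0', 15), ('1', 0), ('2', 8)]
```

Because only leaf modules are hooked and containers here hold no parameters of their own, the per-row counts add up to `sum(p.numel() for p in model.parameters())` (23 for this model).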