TylerYep / torchinfo

View model summaries in PyTorch!
MIT License

Torchinfo does not take sequence length for LSTM models into account for computing MACs #32

Closed: StefanUhlich-sony closed this issue 3 years ago

StefanUhlich-sony commented 3 years ago

First of all, great work @TylerYep - torchinfo looks very nice :)

I have a question about the MAC (multiply-accumulate) computation for LSTM networks: the sequence length does not seem to be taken into account.

Here is the example from the README.md with two different sequence lengths:

import torch
import torch.nn as nn
from torchinfo import summary

class LSTMNet(nn.Module):
    """ Batch-first LSTM model. """
    def __init__(self, vocab_size=20, embed_dim=300, hidden_dim=512, num_layers=2):
        super().__init__()
        self.hidden_dim = hidden_dim
        self.embedding = nn.Embedding(vocab_size, embed_dim)
        self.encoder = nn.LSTM(embed_dim, hidden_dim, num_layers=num_layers, batch_first=True)
        self.decoder = nn.Linear(hidden_dim, vocab_size)

    def forward(self, x):
        embed = self.embedding(x)
        out, hidden = self.encoder(embed)
        out = self.decoder(out)
        out = out.view(-1, out.size(2))
        return out, hidden

seq_length = 100
summary(
    LSTMNet(),
    (1, seq_length),
    dtypes=[torch.long],
    verbose=2,
    col_width=16,
    col_names=["kernel_size", "output_size", "num_params", "mult_adds"],
)

seq_length = 10
summary(
    LSTMNet(),
    (1, seq_length),
    dtypes=[torch.long],
    verbose=2,
    col_width=16,
    col_names=["kernel_size", "output_size", "num_params", "mult_adds"],
)

This gives the following output:

========================================================================================================
Layer (type:depth-idx)                   Kernel Shape     Output Shape     Param #          Mult-Adds
========================================================================================================
├─Embedding: 1-1                         [300, 20]        [1, 100, 300]    6,000            6,000
├─LSTM: 1-2                              --               [1, 100, 512]    3,768,320        3,760,128
|    └─weight_ih_l0                      [2048, 300]
|    └─weight_hh_l0                      [2048, 512]
|    └─weight_ih_l1                      [2048, 512]
|    └─weight_hh_l1                      [2048, 512]
├─Linear: 1-3                            [512, 20]        [1, 100, 20]     10,260           10,240
========================================================================================================
Total params: 3,784,580
Trainable params: 3,784,580
Non-trainable params: 0
Total mult-adds (M): 3.78
========================================================================================================
Input size (MB): 0.00
Forward/backward pass size (MB): 0.67
Params size (MB): 15.14
Estimated Total Size (MB): 15.80
========================================================================================================

========================================================================================================
Layer (type:depth-idx)                   Kernel Shape     Output Shape     Param #          Mult-Adds
========================================================================================================
├─Embedding: 1-1                         [300, 20]        [1, 10, 300]     6,000            6,000
├─LSTM: 1-2                              --               [1, 10, 512]     3,768,320        3,760,128
|    └─weight_ih_l0                      [2048, 300]
|    └─weight_hh_l0                      [2048, 512]
|    └─weight_ih_l1                      [2048, 512]
|    └─weight_hh_l1                      [2048, 512]
├─Linear: 1-3                            [512, 20]        [1, 10, 20]      10,260           10,240
========================================================================================================
Total params: 3,784,580
Trainable params: 3,784,580
Non-trainable params: 0
Total mult-adds (M): 3.78
========================================================================================================
Input size (MB): 0.00
Forward/backward pass size (MB): 0.07
Params size (MB): 15.14
Estimated Total Size (MB): 15.20
========================================================================================================

In both runs, Total mult-adds (M) is 3.78, and the per-layer Mult-Adds are identical, even though the sequence length differs (only the output shapes and the Forward/backward pass size change). I think the problem is that torchinfo ignores seq_length; for example, it does not account for nn.Linear being applied seq_length times. Is this behavior expected?
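A hand count makes the expected scaling concrete. This is a minimal sketch that counts only the matrix products (one MAC per weight element per application) and, as simplifying assumptions, ignores bias additions, the elementwise gate operations, and the embedding lookup. With batch size B = 1, embed_dim E = 300, hidden_dim H = 512, vocab_size V = 20, L = 2 LSTM layers, and sequence length T:

```math
\begin{aligned}
\text{MACs}_{\text{LSTM}}(T) &= B\,T \sum_{\ell=0}^{L-1} 4H\,(d_\ell + H), \qquad d_0 = E,\; d_{\ell>0} = H \\
&= 1 \cdot T \cdot \big[\,2048\,(300 + 512) + 2048\,(512 + 512)\,\big] = 3{,}760{,}128\,T \\
\text{MACs}_{\text{Linear}}(T) &= B\,T\,H\,V = 1 \cdot T \cdot 512 \cdot 20 = 10{,}240\,T
\end{aligned}
```

At T = 1 these reproduce exactly the 3,760,128 and 10,240 shown in the LSTM and Linear rows of both summaries, which is consistent with counting a single time step. At T = 100 the two layers would instead account for 376,012,800 and 1,024,000 mult-adds, and at T = 10 for 37,601,280 and 102,400.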

TylerYep commented 3 years ago

Yes, this looks like a current bug in torchinfo. If you look at the calculate_macs function, you will see that the calculation is not very robust and was designed primarily for ConvNets.

As of now, I'm not sure how I would solve this issue. If you would like to try fixing this, I would appreciate a PR!
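One possible direction, sketched here as a hypothetical helper and not as torchinfo's actual calculate_macs: for a convolution, the MAC count is the kernel size times the number of output positions at which the kernel is applied, and the same idea extends to layers whose weights are reused along the batch and time axes. The names `shape_aware_macs` and `macs_per_application`, and the string-based layer dispatch, are illustrative assumptions. The sketch assumes a batch-first, unidirectional LSTM and ignores biases and gate nonlinearities.

```python
from math import prod


def macs_per_application(weight_shapes):
    """One multiply-accumulate per weight element for a single application of the layer."""
    return sum(prod(shape) for shape in weight_shapes)


def shape_aware_macs(layer_type, weight_shapes, output_shape):
    """Hypothetical sketch: scale per-application MACs by how often the weights are reused."""
    if layer_type == "Linear":
        # nn.Linear acts on the last dimension, so it runs once per position in all other dims
        applications = prod(output_shape[:-1])
    elif layer_type == "LSTM":
        # assumes batch_first=True and unidirectional: output is (batch, seq_len, hidden_dim),
        # and every layer's weight_ih / weight_hh matrices are applied once per time step
        batch, seq_len = output_shape[0], output_shape[1]
        applications = batch * seq_len
    else:
        raise NotImplementedError(f"no MAC rule for {layer_type}")
    return macs_per_application(weight_shapes) * applications


# PyTorch stores weights as (out_features, in_features); only the element count matters here
LSTM_WEIGHTS = [(2048, 300), (2048, 512), (2048, 512), (2048, 512)]
LINEAR_WEIGHTS = [(20, 512)]

for seq_length in (100, 10):
    lstm = shape_aware_macs("LSTM", LSTM_WEIGHTS, (1, seq_length, 512))
    linear = shape_aware_macs("Linear", LINEAR_WEIGHTS, (1, seq_length, 20))
    print(seq_length, lstm, linear)
# prints:
# 100 376012800 1024000
# 10 37601280 102400
```

Because the count is driven by the output shape rather than by the weights alone, the same rule also scales correctly with batch size, e.g. an output of (4, 10, 20) for the Linear layer gives four times the single-batch count.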