microsoft / torchscale

Foundation Architecture for (M)LLMs
https://aka.ms/GeneralAI
MIT License
2.98k stars 201 forks source link

Compatibility with torchsummary #71

Closed: lzqlzzq closed this 9 months ago

lzqlzzq commented 10 months ago

Here are the versions of torchscale and torchsummary in my environment:

torchscale                    0.2.0
torchsummary                  1.5.1

I am using my own custom embedding for an autoregressive task, so I wrap torchscale.architecture.decoder.Decoder with the following code:

from torchscale.architecture.decoder import Decoder
from torchscale.architecture.config import DecoderConfig
from torchsummary import summary

class LatentDecoder(nn.Module):
    """Wrap a torchscale ``Decoder`` so it can be fed pre-computed embeddings.

    The wrapped decoder receives the caller's embeddings through
    ``token_embeddings`` (so a custom embedding module can live outside this
    wrapper) and is called with ``features_only=True`` -- presumably so it
    returns hidden features instead of projected logits; confirm in torchscale.

    Args:
        **kwargs: forwarded unchanged to ``Decoder`` (e.g. ``args=`` and
            ``is_encoder_decoder=``).
    """

    def __init__(self, **kwargs):
        super().__init__()
        self.decoder = Decoder(**kwargs)

    def forward(self, x):
        """Run the wrapped decoder on the embeddings ``x``.

        Args:
            x: input embedding tensor. The layout is assumed to be
                (batch, tokens, embed_dim) -- TODO confirm against Decoder.

        Returns:
            Whatever ``Decoder.forward`` returns for these arguments.
        """
        # Placeholder tokens: per the original author's note, prev_output_tokens
        # is not used when self_attn_relative_position is None. Create it on
        # x's device so that, if the decoder does touch it, it never sits on a
        # different device from the embeddings.
        dummy_tokens = torch.zeros((1, 1), device=x.device)
        # Decoder.forward accepts **kwargs, so a misspelt keyword here would be
        # silently ignored instead of raising -- spell features_only exactly.
        y = self.decoder(
            prev_output_tokens=dummy_tokens,
            token_embeddings=x,
            features_only=True,
        )
        return y

# Decoder hyper-parameters for this reproduction: 6 layers of width 1024,
# FFN width 2048, 8 attention heads, sub-layer normalization enabled, and
# 0.1 for both dropout and drop-path.
dec_config = DecoderConfig(
    decoder_layers=6,
    decoder_attention_heads=8,
    decoder_embed_dim=1024,
    decoder_ffn_embed_dim=2048,
    subln=True,
    dropout=0.1,
    drop_path_rate=0.1,
)

# Keyword arguments are forwarded verbatim to torchscale's Decoder.
# NOTE(review): is_encoder_decoder=True looks surprising for a decoder-only
# autoregressive setup (no encoder_out is ever passed) -- confirm intent.
decoder = LatentDecoder(is_encoder_decoder=True, args=dec_config)

When I tried to use torchsummary to get a summary of the model with this code:

# First attempt: include the batch dimension in input_size.
# NOTE(review): torchsummary's summary() also takes a separate batch_size
# argument, so input_size presumably excludes the batch dimension and
# torchsummary prepends its own. The ValueError in the traceback below
# ("too many values to unpack (expected 3)" at query.size()) shows that the
# attention layer received a tensor with more than 3 dims -- TODO confirm.
input_size = (16, 4, 1024)  # (batch_size, token_index, embedding_size)
summary(decoder, input_size=input_size)

I got this error:

ValueError                                Traceback (most recent call last)
Cell In[18], line 27
     24 input_size = (16, 4, 1024)  # (batch_size, token_index, embedding_size)
     25 decoder = LatentDecoder(args=dec_config, is_encoder_decoder=True)
---> 27 summary(decoder, input_size=input_size)

File ~/.local/lib/python3.11/site-packages/torchsummary/torchsummary.py:72, in summary(model, input_size, batch_size, device)
     68 model.apply(register_hook)
     70 # make a forward pass
     71 # print(x.shape)
---> 72 model(*x)
     74 # remove these hooks
     75 for h in hooks:

File ~/.local/lib/python3.11/site-packages/torch/nn/modules/module.py:1501, in Module._call_impl(self, *args, **kwargs)
   1496 # If we don't have any hooks, we want to skip the rest of the logic in
   1497 # this function, and just call forward.
   1498 if not (self._backward_hooks or self._backward_pre_hooks or self._forward_hooks or self._forward_pre_hooks
   1499         or _global_backward_pre_hooks or _global_backward_hooks
   1500         or _global_forward_hooks or _global_forward_pre_hooks):
-> 1501     return forward_call(*args, **kwargs)
   1502 # Do not call functions when jit is used
   1503 full_backward_hooks, non_full_backward_hooks = [], []

Cell In[18], line 11, in LatentDecoder.forward(self, x)
     10 def forward(self, x):
---> 11     y = self.decoder(prev_output_tokens=torch.zeros((1,1)), token_embeddings=x, features_only=True)
     12     return y

File ~/.local/lib/python3.11/site-packages/torch/nn/modules/module.py:1538, in Module._call_impl(self, *args, **kwargs)
   1535     bw_hook = hooks.BackwardHook(self, full_backward_hooks, backward_pre_hooks)
   1536     args = bw_hook.setup_input_hook(args)
-> 1538 result = forward_call(*args, **kwargs)
   1539 if _global_forward_hooks or self._forward_hooks:
   1540     for hook_id, hook in (
   1541         *_global_forward_hooks.items(),
   1542         *self._forward_hooks.items(),
   1543     ):

File ~/.local/lib/python3.11/site-packages/torchscale/architecture/decoder.py:437, in Decoder.forward(self, prev_output_tokens, self_attn_padding_mask, encoder_out, incremental_state, features_only, return_all_hiddens, token_embeddings, **kwargs)
    434     if idx not in incremental_state:
    435         incremental_state[idx] = {}
--> 437 x, layer_attn, _, l_aux_i = layer(
    438     x,
    439     encoder_out["encoder_out"] if encoder_out is not None else None,
    440     encoder_out["encoder_padding_mask"]
    441     if encoder_out is not None
    442     else None,
    443     incremental_state[idx] if incremental_state is not None else None,
    444     self_attn_mask=self_attn_mask,
    445     self_attn_padding_mask=self_attn_padding_mask,
    446     self_attn_rel_pos=self_attn_rel_pos_bias,
    447     cross_attn_rel_pos=cross_attn_rel_pos_bias,
    448 )
    449 l_aux.append(l_aux_i)
    450 inner_states.append(x)

File ~/.local/lib/python3.11/site-packages/torch/nn/modules/module.py:1538, in Module._call_impl(self, *args, **kwargs)
   1535     bw_hook = hooks.BackwardHook(self, full_backward_hooks, backward_pre_hooks)
   1536     args = bw_hook.setup_input_hook(args)
-> 1538 result = forward_call(*args, **kwargs)
   1539 if _global_forward_hooks or self._forward_hooks:
   1540     for hook_id, hook in (
   1541         *_global_forward_hooks.items(),
   1542         *self._forward_hooks.items(),
   1543     ):

File ~/.local/lib/python3.11/site-packages/torchscale/architecture/decoder.py:148, in DecoderLayer.forward(self, x, encoder_out, encoder_padding_mask, incremental_state, self_attn_mask, self_attn_padding_mask, self_attn_rel_pos, cross_attn_rel_pos)
    145 if self.normalize_before:
    146     x = self.self_attn_layer_norm(x)
--> 148 x, attn = self.self_attn(
    149     query=x,
    150     key=x,
    151     value=x,
    152     key_padding_mask=self_attn_padding_mask,
    153     incremental_state=incremental_state,
    154     attn_mask=self_attn_mask,
    155     rel_pos=self_attn_rel_pos,
    156 )
    157 x = self.dropout_module(x)
    159 if self.drop_path is not None:

File ~/.local/lib/python3.11/site-packages/torch/nn/modules/module.py:1538, in Module._call_impl(self, *args, **kwargs)
   1535     bw_hook = hooks.BackwardHook(self, full_backward_hooks, backward_pre_hooks)
   1536     args = bw_hook.setup_input_hook(args)
-> 1538 result = forward_call(*args, **kwargs)
   1539 if _global_forward_hooks or self._forward_hooks:
   1540     for hook_id, hook in (
   1541         *_global_forward_hooks.items(),
   1542         *self._forward_hooks.items(),
   1543     ):

File ~/.local/lib/python3.11/site-packages/torchscale/component/multihead_attention.py:75, in MultiheadAttention.forward(self, query, key, value, incremental_state, key_padding_mask, attn_mask, rel_pos)
     65 def forward(
     66     self,
     67     query,
   (...)
     73     rel_pos=None,
     74 ):
---> 75     bsz, tgt_len, embed_dim = query.size()
     76     src_len = tgt_len
     77     assert embed_dim == self.embed_dim, f"query dim {embed_dim} != {self.embed_dim}"

ValueError: too many values to unpack (expected 3)

Then I tried omitting the batch_size dimension:

# Second attempt: leave the batch dimension to torchsummary.
# NOTE(review): the IndexError in the traceback below is raised inside
# torchsummary's forward hook at input[0]. The hook is handed only the
# module's positional args, and self_attn is invoked with keyword arguments
# only, so that tuple is presumably empty -- TODO confirm.
input_size = (4, 1024)  # (token_index, embedding_size)

I then got another error:

IndexError                                Traceback (most recent call last)
Cell In[19], line 27
     24 input_size = (4, 1024)  # (token_index, embedding_size)
     25 decoder = LatentDecoder(args=dec_config, is_encoder_decoder=True)
---> 27 summary(decoder, input_size=input_size)

File ~/.local/lib/python3.11/site-packages/torchsummary/torchsummary.py:72, in summary(model, input_size, batch_size, device)
     68 model.apply(register_hook)
     70 # make a forward pass
     71 # print(x.shape)
---> 72 model(*x)
     74 # remove these hooks
     75 for h in hooks:

File ~/.local/lib/python3.11/site-packages/torch/nn/modules/module.py:1501, in Module._call_impl(self, *args, **kwargs)
   1496 # If we don't have any hooks, we want to skip the rest of the logic in
   1497 # this function, and just call forward.
   1498 if not (self._backward_hooks or self._backward_pre_hooks or self._forward_hooks or self._forward_pre_hooks
   1499         or _global_backward_pre_hooks or _global_backward_hooks
   1500         or _global_forward_hooks or _global_forward_pre_hooks):
-> 1501     return forward_call(*args, **kwargs)
   1502 # Do not call functions when jit is used
   1503 full_backward_hooks, non_full_backward_hooks = [], []

Cell In[19], line 11, in LatentDecoder.forward(self, x)
     10 def forward(self, x):
---> 11     y = self.decoder(prev_output_tokens=torch.zeros((1,1)), token_embeddings=x, features_only=True)
     12     return y

File ~/.local/lib/python3.11/site-packages/torch/nn/modules/module.py:1538, in Module._call_impl(self, *args, **kwargs)
   1535     bw_hook = hooks.BackwardHook(self, full_backward_hooks, backward_pre_hooks)
   1536     args = bw_hook.setup_input_hook(args)
-> 1538 result = forward_call(*args, **kwargs)
   1539 if _global_forward_hooks or self._forward_hooks:
   1540     for hook_id, hook in (
   1541         *_global_forward_hooks.items(),
   1542         *self._forward_hooks.items(),
   1543     ):

File ~/.local/lib/python3.11/site-packages/torchscale/architecture/decoder.py:437, in Decoder.forward(self, prev_output_tokens, self_attn_padding_mask, encoder_out, incremental_state, features_only, return_all_hiddens, token_embeddings, **kwargs)
    434     if idx not in incremental_state:
    435         incremental_state[idx] = {}
--> 437 x, layer_attn, _, l_aux_i = layer(
    438     x,
    439     encoder_out["encoder_out"] if encoder_out is not None else None,
    440     encoder_out["encoder_padding_mask"]
    441     if encoder_out is not None
    442     else None,
    443     incremental_state[idx] if incremental_state is not None else None,
    444     self_attn_mask=self_attn_mask,
    445     self_attn_padding_mask=self_attn_padding_mask,
    446     self_attn_rel_pos=self_attn_rel_pos_bias,
    447     cross_attn_rel_pos=cross_attn_rel_pos_bias,
    448 )
    449 l_aux.append(l_aux_i)
    450 inner_states.append(x)

File ~/.local/lib/python3.11/site-packages/torch/nn/modules/module.py:1538, in Module._call_impl(self, *args, **kwargs)
   1535     bw_hook = hooks.BackwardHook(self, full_backward_hooks, backward_pre_hooks)
   1536     args = bw_hook.setup_input_hook(args)
-> 1538 result = forward_call(*args, **kwargs)
   1539 if _global_forward_hooks or self._forward_hooks:
   1540     for hook_id, hook in (
   1541         *_global_forward_hooks.items(),
   1542         *self._forward_hooks.items(),
   1543     ):

File ~/.local/lib/python3.11/site-packages/torchscale/architecture/decoder.py:148, in DecoderLayer.forward(self, x, encoder_out, encoder_padding_mask, incremental_state, self_attn_mask, self_attn_padding_mask, self_attn_rel_pos, cross_attn_rel_pos)
    145 if self.normalize_before:
    146     x = self.self_attn_layer_norm(x)
--> 148 x, attn = self.self_attn(
    149     query=x,
    150     key=x,
    151     value=x,
    152     key_padding_mask=self_attn_padding_mask,
    153     incremental_state=incremental_state,
    154     attn_mask=self_attn_mask,
    155     rel_pos=self_attn_rel_pos,
    156 )
    157 x = self.dropout_module(x)
    159 if self.drop_path is not None:

File ~/.local/lib/python3.11/site-packages/torch/nn/modules/module.py:1547, in Module._call_impl(self, *args, **kwargs)
   1545     hook_result = hook(self, args, kwargs, result)
   1546 else:
-> 1547     hook_result = hook(self, args, result)
   1549 if hook_result is not None:
   1550     result = hook_result

File ~/.local/lib/python3.11/site-packages/torchsummary/torchsummary.py:19, in summary.<locals>.register_hook.<locals>.hook(module, input, output)
     17 m_key = "%s-%i" % (class_name, module_idx + 1)
     18 summary[m_key] = OrderedDict()
---> 19 summary[m_key]["input_shape"] = list(input[0].size())
     20 summary[m_key]["input_shape"][0] = batch_size
     21 if isinstance(output, (list, tuple)):

IndexError: tuple index out of range

Why does this happen, and what should I do to fix it?

Thanks!

donglixp commented 9 months ago

It seems like a bug in torchsummary.