Here are the versions of torchscale and torchsummary in my environment:
torchscale 0.2.0
torchsummary 1.5.1
I am using a custom embedding for an auto-regressive task, so I wrap torchscale.architecture.decoder.Decoder with the following code:
```python
import torch
import torch.nn as nn

from torchscale.architecture.decoder import Decoder
from torchscale.architecture.config import DecoderConfig
from torchsummary import summary


class LatentDecoder(nn.Module):
    def __init__(self, **kwargs):
        super().__init__()
        self.decoder = Decoder(**kwargs)

    def forward(self, x):
        y = self.decoder(
            prev_output_tokens=torch.zeros((1, 1)),  # Not used when self_attn_relative_position is None
            token_embeddings=x,
            features_only=True,
        )
        return y


dec_config = DecoderConfig(
    subln=True,  # use sublayer normalization
    dropout=0.1,
    drop_path_rate=0.1,
    decoder_layers=6,
    decoder_embed_dim=1024,
    decoder_ffn_embed_dim=2048,
    decoder_attention_heads=8,
)
decoder = LatentDecoder(args=dec_config, is_encoder_decoder=True)
```
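For reference, this is how I call the wrapper directly, outside of torchsummary (a minimal sketch; I am assuming the decoder may return either a tensor or a (features, aux) tuple when features_only=True, hence the defensive unwrap):

```python
# Calling the wrapper directly with a random latent batch.
# Assumption: the return value may be a tensor or a (features, aux) tuple.
x = torch.randn(16, 4, 1024)   # (batch_size, token_index, embedding_size)
out = decoder(x)
features = out[0] if isinstance(out, tuple) else out
print(features.shape)          # expected: torch.Size([16, 4, 1024])
```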
When I tried to use torchsummary to get a summary of the model with this code:
```python
input_size = (16, 4, 1024)  # (batch_size, token_index, embedding_size)
summary(decoder, input_size=input_size)
```
I got this error:
```
ValueError                                Traceback (most recent call last)
Cell In[18], line 27
     24 input_size = (16, 4, 1024) # (batch_size, token_index, embedding_size)
     25 decoder = LatentDecoder(args=dec_config, is_encoder_decoder=True)
---> 27 summary(decoder, input_size=input_size)

File ~/.local/lib/python3.11/site-packages/torchsummary/torchsummary.py:72, in summary(model, input_size, batch_size, device)
     68 model.apply(register_hook)
     70 # make a forward pass
     71 # print(x.shape)
---> 72 model(*x)
     74 # remove these hooks
     75 for h in hooks:

File ~/.local/lib/python3.11/site-packages/torch/nn/modules/module.py:1501, in Module._call_impl(self, *args, **kwargs)
   1496 # If we don't have any hooks, we want to skip the rest of the logic in
   1497 # this function, and just call forward.
   1498 if not (self._backward_hooks or self._backward_pre_hooks or self._forward_hooks or self._forward_pre_hooks
   1499         or _global_backward_pre_hooks or _global_backward_hooks
   1500         or _global_forward_hooks or _global_forward_pre_hooks):
-> 1501     return forward_call(*args, **kwargs)
   1502 # Do not call functions when jit is used
   1503 full_backward_hooks, non_full_backward_hooks = [], []

Cell In[18], line 11, in LatentDecoder.forward(self, x)
     10 def forward(self, x):
---> 11     y = self.decoder(prev_output_tokens=torch.zeros((1,1)), token_embeddings=x, features_only=True)
     12     return y

File ~/.local/lib/python3.11/site-packages/torch/nn/modules/module.py:1538, in Module._call_impl(self, *args, **kwargs)
   1535 bw_hook = hooks.BackwardHook(self, full_backward_hooks, backward_pre_hooks)
   1536 args = bw_hook.setup_input_hook(args)
-> 1538 result = forward_call(*args, **kwargs)
   1539 if _global_forward_hooks or self._forward_hooks:
   1540     for hook_id, hook in (
   1541         *_global_forward_hooks.items(),
   1542         *self._forward_hooks.items(),
   1543     ):

File ~/.local/lib/python3.11/site-packages/torchscale/architecture/decoder.py:437, in Decoder.forward(self, prev_output_tokens, self_attn_padding_mask, encoder_out, incremental_state, features_only, return_all_hiddens, token_embeddings, **kwargs)
    434 if idx not in incremental_state:
    435     incremental_state[idx] = {}
--> 437 x, layer_attn, _, l_aux_i = layer(
    438     x,
    439     encoder_out["encoder_out"] if encoder_out is not None else None,
    440     encoder_out["encoder_padding_mask"]
    441     if encoder_out is not None
    442     else None,
    443     incremental_state[idx] if incremental_state is not None else None,
    444     self_attn_mask=self_attn_mask,
    445     self_attn_padding_mask=self_attn_padding_mask,
    446     self_attn_rel_pos=self_attn_rel_pos_bias,
    447     cross_attn_rel_pos=cross_attn_rel_pos_bias,
    448 )
    449 l_aux.append(l_aux_i)
    450 inner_states.append(x)

File ~/.local/lib/python3.11/site-packages/torch/nn/modules/module.py:1538, in Module._call_impl(self, *args, **kwargs)
   1535 bw_hook = hooks.BackwardHook(self, full_backward_hooks, backward_pre_hooks)
   1536 args = bw_hook.setup_input_hook(args)
-> 1538 result = forward_call(*args, **kwargs)
   1539 if _global_forward_hooks or self._forward_hooks:
   1540     for hook_id, hook in (
   1541         *_global_forward_hooks.items(),
   1542         *self._forward_hooks.items(),
   1543     ):

File ~/.local/lib/python3.11/site-packages/torchscale/architecture/decoder.py:148, in DecoderLayer.forward(self, x, encoder_out, encoder_padding_mask, incremental_state, self_attn_mask, self_attn_padding_mask, self_attn_rel_pos, cross_attn_rel_pos)
    145 if self.normalize_before:
    146     x = self.self_attn_layer_norm(x)
--> 148 x, attn = self.self_attn(
    149     query=x,
    150     key=x,
    151     value=x,
    152     key_padding_mask=self_attn_padding_mask,
    153     incremental_state=incremental_state,
    154     attn_mask=self_attn_mask,
    155     rel_pos=self_attn_rel_pos,
    156 )
    157 x = self.dropout_module(x)
    159 if self.drop_path is not None:

File ~/.local/lib/python3.11/site-packages/torch/nn/modules/module.py:1538, in Module._call_impl(self, *args, **kwargs)
   1535 bw_hook = hooks.BackwardHook(self, full_backward_hooks, backward_pre_hooks)
   1536 args = bw_hook.setup_input_hook(args)
-> 1538 result = forward_call(*args, **kwargs)
   1539 if _global_forward_hooks or self._forward_hooks:
   1540 for hook_id, hook in (
   1541     *_global_forward_hooks.items(),
   1542     *self._forward_hooks.items(),
   1543 ):

File ~/.local/lib/python3.11/site-packages/torchscale/component/multihead_attention.py:75, in MultiheadAttention.forward(self, query, key, value, incremental_state, key_padding_mask, attn_mask, rel_pos)
     65 def forward(
     66     self,
     67     query,
    (...)
     73     rel_pos=None,
     74 ):
---> 75     bsz, tgt_len, embed_dim = query.size()
     76     src_len = tgt_len
     77     assert embed_dim == self.embed_dim, f"query dim {embed_dim} != {self.embed_dim}"

ValueError: too many values to unpack (expected 3)
```
Then I tried to omit the batch_size dim:
```python
input_size = (4, 1024)  # (token_index, embedding_size)
```
I then had another error:
```
IndexError                                Traceback (most recent call last)
Cell In[19], line 27
     24 input_size = (4, 1024) # (token_index, embedding_size)
     25 decoder = LatentDecoder(args=dec_config, is_encoder_decoder=True)
---> 27 summary(decoder, input_size=input_size)

File ~/.local/lib/python3.11/site-packages/torchsummary/torchsummary.py:72, in summary(model, input_size, batch_size, device)
     68 model.apply(register_hook)
     70 # make a forward pass
     71 # print(x.shape)
---> 72 model(*x)
     74 # remove these hooks
     75 for h in hooks:

File ~/.local/lib/python3.11/site-packages/torch/nn/modules/module.py:1501, in Module._call_impl(self, *args, **kwargs)
   1496 # If we don't have any hooks, we want to skip the rest of the logic in
   1497 # this function, and just call forward.
   1498 if not (self._backward_hooks or self._backward_pre_hooks or self._forward_hooks or self._forward_pre_hooks
   1499         or _global_backward_pre_hooks or _global_backward_hooks
   1500         or _global_forward_hooks or _global_forward_pre_hooks):
-> 1501     return forward_call(*args, **kwargs)
   1502 # Do not call functions when jit is used
   1503 full_backward_hooks, non_full_backward_hooks = [], []

Cell In[19], line 11, in LatentDecoder.forward(self, x)
     10 def forward(self, x):
---> 11     y = self.decoder(prev_output_tokens=torch.zeros((1,1)), token_embeddings=x, features_only=True)
     12     return y

File ~/.local/lib/python3.11/site-packages/torch/nn/modules/module.py:1538, in Module._call_impl(self, *args, **kwargs)
   1535 bw_hook = hooks.BackwardHook(self, full_backward_hooks, backward_pre_hooks)
   1536 args = bw_hook.setup_input_hook(args)
-> 1538 result = forward_call(*args, **kwargs)
   1539 if _global_forward_hooks or self._forward_hooks:
   1540     for hook_id, hook in (
   1541         *_global_forward_hooks.items(),
   1542         *self._forward_hooks.items(),
   1543     ):

File ~/.local/lib/python3.11/site-packages/torchscale/architecture/decoder.py:437, in Decoder.forward(self, prev_output_tokens, self_attn_padding_mask, encoder_out, incremental_state, features_only, return_all_hiddens, token_embeddings, **kwargs)
    434 if idx not in incremental_state:
    435     incremental_state[idx] = {}
--> 437 x, layer_attn, _, l_aux_i = layer(
    438     x,
    439     encoder_out["encoder_out"] if encoder_out is not None else None,
    440     encoder_out["encoder_padding_mask"]
    441     if encoder_out is not None
    442     else None,
    443     incremental_state[idx] if incremental_state is not None else None,
    444     self_attn_mask=self_attn_mask,
    445     self_attn_padding_mask=self_attn_padding_mask,
    446     self_attn_rel_pos=self_attn_rel_pos_bias,
    447     cross_attn_rel_pos=cross_attn_rel_pos_bias,
    448 )
    449 l_aux.append(l_aux_i)
    450 inner_states.append(x)

File ~/.local/lib/python3.11/site-packages/torch/nn/modules/module.py:1538, in Module._call_impl(self, *args, **kwargs)
   1535 bw_hook = hooks.BackwardHook(self, full_backward_hooks, backward_pre_hooks)
   1536 args = bw_hook.setup_input_hook(args)
-> 1538 result = forward_call(*args, **kwargs)
   1539 if _global_forward_hooks or self._forward_hooks:
   1540     for hook_id, hook in (
   1541         *_global_forward_hooks.items(),
   1542         *self._forward_hooks.items(),
   1543     ):

File ~/.local/lib/python3.11/site-packages/torchscale/architecture/decoder.py:148, in DecoderLayer.forward(self, x, encoder_out, encoder_padding_mask, incremental_state, self_attn_mask, self_attn_padding_mask, self_attn_rel_pos, cross_attn_rel_pos)
    145 if self.normalize_before:
    146     x = self.self_attn_layer_norm(x)
--> 148 x, attn = self.self_attn(
    149     query=x,
    150     key=x,
    151     value=x,
    152     key_padding_mask=self_attn_padding_mask,
    153     incremental_state=incremental_state,
    154     attn_mask=self_attn_mask,
    155     rel_pos=self_attn_rel_pos,
    156 )
    157 x = self.dropout_module(x)
    159 if self.drop_path is not None:

File ~/.local/lib/python3.11/site-packages/torch/nn/modules/module.py:1547, in Module._call_impl(self, *args, **kwargs)
   1545     hook_result = hook(self, args, kwargs, result)
   1546 else:
-> 1547     hook_result = hook(self, args, result)
   1549 if hook_result is not None:
   1550     result = hook_result

File ~/.local/lib/python3.11/site-packages/torchsummary/torchsummary.py:19, in summary.<locals>.register_hook.<locals>.hook(module, input, output)
     17 m_key = "%s-%i" % (class_name, module_idx + 1)
     18 summary[m_key] = OrderedDict()
---> 19 summary[m_key]["input_shape"] = list(input[0].size())
     20 summary[m_key]["input_shape"][0] = batch_size
     21 if isinstance(output, (list, tuple)):

IndexError: tuple index out of range
```
Why does this happen, and what should I do to fix it?
Thanks!
It seems like a bug in torchsummary.
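As far as I can tell, both failures come from assumptions torchsummary makes rather than from torchscale itself. In the first attempt, torchsummary builds its test input as torch.rand(2, *input_size) (it prepends its own batch dimension of 2), so with input_size=(16, 4, 1024) the decoder receives a 4-D tensor and `bsz, tgt_len, embed_dim = query.size()` in MultiheadAttention has too many values to unpack; input_size is meant to describe a single sample, e.g. (4, 1024). In the second attempt, torchsummary's forward hooks read input[0], but DecoderLayer calls self.self_attn(query=x, key=x, value=x, ...) with keyword arguments only, and forward hooks only see positional arguments, so the hook's input tuple is empty and input[0] raises IndexError.

One workaround is to skip torchsummary's hooks entirely and count parameters yourself. Below is a minimal sketch under that assumption; param_summary is a hypothetical helper (not part of torchsummary or torchscale) that relies only on nn.Module.named_modules() and parameters():

```python
# Hook-free alternative to torchsummary: count parameters per submodule.
# `param_summary` is a made-up helper name, not a library function.
import torch.nn as nn


def param_summary(model: nn.Module) -> int:
    total = 0
    for name, module in model.named_modules():
        # recurse=False so each parameter is counted exactly once,
        # under the module that directly owns it.
        n = sum(p.numel() for p in module.parameters(recurse=False))
        if n > 0:
            print(f"{name or type(model).__name__:<60s}{n:>14,d}")
        total += n
    print(f"{'Total params':<60s}{total:>14,d}")
    return total


param_summary(decoder)
```

Alternatively, torchinfo (the maintained successor to torchsummary) accepts real tensors via input_data, e.g. torchinfo.summary(decoder, input_data=torch.randn(16, 4, 1024)); I have not verified it against torchscale 0.2.0, so treat that as a suggestion rather than a confirmed fix.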