showlab / Show-o

Repository for Show-o, One Single Transformer to Unify Multimodal Understanding and Generation.
https://arxiv.org/abs/2408.12528
Apache License 2.0

runtime error #10

Open junwenxiong opened 2 months ago

junwenxiong commented 2 months ago

Traceback (most recent call last):
  File "inference_t2i.py", line 372, in <module>
    gen_token_ids = model.t2i_generate(
  File "Show-o/models/modeling_showo.py", line 108, in t2i_generate
    cond_logits, uncond_logits = self(model_input, attention_mask=attention_mask).chunk(2)
  File "python3.8/site-packages/torch/nn/modules/module.py", line 1501, in _call_impl
    return forward_call(*args, **kwargs)
  File "Show-o/models/modeling_showo.py", line 63, in forward
    logits = self.showo(input_ids=input_ids, attention_mask=attention_mask)['logits']
  File "python3.8/site-packages/torch/nn/modules/module.py", line 1501, in _call_impl
    return forward_call(*args, **kwargs)
  File "Show-o/models/phi.py", line 1191, in forward
    input_ids=input_ids,
  File "python3.8/site-packages/torch/nn/modules/module.py", line 1501, in _call_impl
    return forward_call(*args, **kwargs)
  File "Show-o/models/phi.py", line 1068, in forward
    hidden_states,
  File "python3.8/site-packages/torch/nn/modules/module.py", line 1501, in _call_impl
    return forward_call(*args, **kwargs)
  File "Show-o/models/phi.py", line 800, in forward
    hidden_states=hidden_states,
  File "python3.8/site-packages/torch/nn/modules/module.py", line 1501, in _call_impl
    return forward_call(*args, **kwargs)
  File "Show-o/models/phi.py", line 318, in forward
    query_states = self.q_layernorm(query_states)
  File "python3.8/site-packages/torch/nn/modules/module.py", line 1501, in _call_impl
    return forward_call(*args, **kwargs)
  File "python3.8/site-packages/torch/nn/modules/normalization.py", line 190, in forward
    return F.layer_norm(
  File "python3.8/site-packages/torch/nn/functional.py", line 2515, in layer_norm
    return torch.layer_norm(input, normalized_shape, weight, bias, eps, torch.backends.cudnn.enabled)
RuntimeError: Given normalized_shape=[64], expected input with shape [*, 64], but got input of size[2, 387, 2048]

Running inference with the default configuration raises this error, which looks like a mismatch in the feature dimension. Is there a fix for this?
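The mismatch can be reproduced in isolation. This is a minimal sketch assuming PyTorch, using the sizes from the traceback (batch 2, 387 tokens, hidden size 2048, and head_dim 64 from `normalized_shape=[64]`); `reproduce_qk_norm_error` is a hypothetical helper, not part of Show-o.

```python
import torch
import torch.nn as nn

# Sizes taken from the traceback: batch 2, sequence length 387, hidden size 2048.
# normalized_shape=[64] in the error is the per-head dimension (head_dim).
BSZ, Q_LEN, HIDDEN_SIZE, HEAD_DIM = 2, 387, 2048, 64


def reproduce_qk_norm_error() -> str:
    """Hypothetical helper: apply a head_dim-sized LayerNorm to un-split q_proj output.

    Returns the RuntimeError message, or an empty string if no error was raised.
    """
    q_layernorm = nn.LayerNorm(HEAD_DIM)  # expects the last dimension to be 64
    query_states = torch.randn(BSZ, Q_LEN, HIDDEN_SIZE)  # heads not split yet: last dim is 2048
    try:
        q_layernorm(query_states)
    except RuntimeError as err:
        return str(err)
    return ""


if __name__ == "__main__":
    print(reproduce_qk_norm_error())
```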

Sierkinhane commented 2 months ago

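Hi, sorry for the late reply. This error is caused by where the QK-norm layers are applied. q_layernorm and k_layernorm normalize over head_dim (64), but in class PhiAttention in phi.py they are called on the q_proj/k_proj outputs before the heads are split, when the last dimension is still hidden_size (2048). We use class PhiSdpaAttention(PhiAttention), and we will update the corresponding code of class PhiAttention in phi.py. In the meantime, you can apply the change below directly, which moves the QK norm after the view/transpose: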

query_states = self.q_proj(hidden_states)
key_states = self.k_proj(hidden_states)
value_states = self.v_proj(hidden_states)

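# Original position (remove): the last dim here is still hidden_size (2048),
# so a LayerNorm over head_dim (64) raises the RuntimeError above.
#if self.qk_layernorm:
#    query_states = self.q_layernorm(query_states)
#    key_states = self.k_layernorm(key_states)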

query_states = query_states.view(bsz, q_len, self.num_heads, self.head_dim).transpose(1, 2)
key_states = key_states.view(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2)
value_states = value_states.view(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2)

# Apply the QK norm here instead: after view/transpose the last dim is head_dim (64)
if self.qk_layernorm:
    query_states = self.q_layernorm(query_states)
    key_states = self.k_layernorm(key_states)
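To see the corrected ordering end to end, here is a stripped-down attention module. It is a hypothetical sketch, not Show-o's PhiAttention: the class name `QKNormAttentionSketch` is made up, it assumes PyTorch >= 2.0 for `F.scaled_dot_product_attention`, and it omits rotary embeddings, the KV cache, and dropout.

```python
from typing import Optional

import torch
import torch.nn as nn
import torch.nn.functional as F


class QKNormAttentionSketch(nn.Module):
    """Hypothetical, simplified attention block showing where QK layernorm belongs.

    Models only: project -> split heads -> QK layernorm -> attention -> merge heads.
    Rotary embeddings, KV cache, and dropout from phi.py are deliberately omitted.
    """

    def __init__(self, hidden_size: int = 2048, num_heads: int = 32) -> None:
        super().__init__()
        assert hidden_size % num_heads == 0, "hidden_size must be divisible by num_heads"
        self.num_heads = num_heads
        self.head_dim = hidden_size // num_heads
        self.q_proj = nn.Linear(hidden_size, hidden_size)
        self.k_proj = nn.Linear(hidden_size, hidden_size)
        self.v_proj = nn.Linear(hidden_size, hidden_size)
        self.dense = nn.Linear(hidden_size, hidden_size)  # output projection
        # Per-head norms over head_dim (64 for 2048 / 32), matching the error message.
        self.q_layernorm = nn.LayerNorm(self.head_dim)
        self.k_layernorm = nn.LayerNorm(self.head_dim)

    def _split_heads(self, x: torch.Tensor) -> torch.Tensor:
        # [bsz, q_len, hidden_size] -> [bsz, num_heads, q_len, head_dim]
        bsz, q_len, _ = x.shape
        return x.view(bsz, q_len, self.num_heads, self.head_dim).transpose(1, 2)

    def forward(self, hidden_states: torch.Tensor,
                attn_mask: Optional[torch.Tensor] = None) -> torch.Tensor:
        bsz, q_len, _ = hidden_states.shape
        query_states = self._split_heads(self.q_proj(hidden_states))
        key_states = self._split_heads(self.k_proj(hidden_states))
        value_states = self._split_heads(self.v_proj(hidden_states))
        # QK layernorm after the split: the last dim is now head_dim, so shapes match.
        query_states = self.q_layernorm(query_states)
        key_states = self.k_layernorm(key_states)
        attn_output = F.scaled_dot_product_attention(
            query_states, key_states, value_states, attn_mask=attn_mask
        )
        # [bsz, num_heads, q_len, head_dim] -> [bsz, q_len, hidden_size]
        attn_output = attn_output.transpose(1, 2).reshape(bsz, q_len, -1)
        return self.dense(attn_output)


if __name__ == "__main__":
    attn = QKNormAttentionSketch(hidden_size=128, num_heads=2)  # head_dim = 64
    print(attn(torch.randn(2, 10, 128)).shape)  # torch.Size([2, 10, 128])
```

With `hidden_size=128, num_heads=2`, the per-head LayerNorm only ever sees tensors whose last dimension is 64, so the shape check that failed in the traceback passes.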