Open LianghuiGuo opened 7 months ago
I ran into the same problem.
I manually changed the device, and now it runs.
def apply_rotary_pos_emb(q, k, cos, sin, position_ids):
    # print(q.device, k.device, cos.device, sin.device, position_ids.device)
    # cuda:0 cuda:0 cpu cpu cuda:0
    # The first two dimensions of cos and sin are always 1, so we can `squeeze` them.
    cos = cos.squeeze(1).squeeze(0)  # [seq_len, dim]
    sin = sin.squeeze(1).squeeze(0)  # [seq_len, dim]
    # Move cos and sin to the same device as position_ids
    cos = cos.to(position_ids.device)
    sin = sin.to(position_ids.device)
    cos = cos[position_ids].unsqueeze(1)  # [bs, 1, seq_len, dim]
    sin = sin[position_ids].unsqueeze(1)  # [bs, 1, seq_len, dim]
    q_embed = (q * cos) + (rotate_half(q) * sin)
    k_embed = (k * cos) + (rotate_half(k) * sin)
    return q_embed, k_embed
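For anyone who wants to check the workaround outside the training script, here is a minimal self-contained sketch. It is not taken from the repo: `rotate_half` is assumed to be the usual LLaMA-style helper (it is not shown above), the tensor shapes are made up, and `cos`/`sin` are deliberately left on the CPU to imitate the mismatch in the printed devices. It falls back to the CPU if no GPU is available.

```python
import torch


def rotate_half(x):
    # Assumed definition (common LLaMA-style helper): (x1, x2) -> (-x2, x1)
    half = x.shape[-1] // 2
    x1, x2 = x[..., :half], x[..., half:]
    return torch.cat((-x2, x1), dim=-1)


def apply_rotary_pos_emb(q, k, cos, sin, position_ids):
    # Squeeze the two leading size-1 dims, then align the device with position_ids
    cos = cos.squeeze(1).squeeze(0).to(position_ids.device)  # [seq_len, dim]
    sin = sin.squeeze(1).squeeze(0).to(position_ids.device)  # [seq_len, dim]
    cos = cos[position_ids].unsqueeze(1)  # [bs, 1, seq_len, dim]
    sin = sin[position_ids].unsqueeze(1)  # [bs, 1, seq_len, dim]
    q_embed = (q * cos) + (rotate_half(q) * sin)
    k_embed = (k * cos) + (rotate_half(k) * sin)
    return q_embed, k_embed


# Hypothetical shapes, chosen only for illustration
bs, heads, seq_len, dim = 2, 4, 8, 16
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

q = torch.randn(bs, heads, seq_len, dim, device=device)
k = torch.randn(bs, heads, seq_len, dim, device=device)

# cos/sin stay on the CPU on purpose, like the cached buffers in the report.
# cos = 1 and sin = 0 correspond to a zero rotation angle, so q and k
# should come back unchanged.
cos = torch.ones(1, 1, seq_len, dim)
sin = torch.zeros(1, 1, seq_len, dim)
position_ids = torch.arange(seq_len, device=device).unsqueeze(0).expand(bs, -1)

q_embed, k_embed = apply_rotary_pos_emb(q, k, cos, sin, position_ids)
print(q_embed.device, k_embed.device, q_embed.shape)
```

Aligning to `position_ids.device` works here because `q`, `k`, and `position_ids` are already on the same device. If the model is sharded across several GPUs, aligning to `q.device` might be the safer choice, but I have not tested that.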
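I hit this problem when finetuning mPLUG-Owl2. The environment was set up following the official instructions, and the data is 32 test samples built following LLaVA's data format.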