if cur_input_ids.numel() > 0:
    # Splice mask/pos feature embeddings into the token-embedding sequence at every
    # '<mask>' position. For each occurrence, the tokens before it are embedded
    # normally, then the corresponding mask feature and pos feature are inserted,
    # and the two label slots they occupy are masked out with IGNORE_INDEX.
    mask_token_id = self.tokenizer.convert_tokens_to_ids(['<mask>'])[0]
    mask_idx = torch.nonzero(cur_input_ids == mask_token_id)
    # NOTE(review): the original code only asserted this in the non-tune branch;
    # a silent count mismatch while tuning the adapter could desynchronize the
    # sequence build (a plausible cause of the reported training hang), so the
    # check is applied on both paths now.
    assert len(mask_idx) == len(mask_feats[batch_idx]), "mask num not equal to mask feats"
    # When only the mm-MLP adapter is tuned (with im_start/end tokens in use),
    # everything flowing from the frozen backbone is detached from autograd.
    freeze = (getattr(self.config, 'tune_mm_mlp_adapter', False)
              and getattr(self.config, 'mm_use_im_start_end', False))
    _l = 0  # left edge of the next un-embedded span of cur_input_ids
    for i, idx in enumerate(mask_idx):
        # Tokens strictly before this '<mask>' occurrence.
        tok_embeds = self.get_model().embed_tokens(cur_input_ids[_l:idx[0]])
        mask_feat = mask_feats[batch_idx][i:i + 1]
        pos_feat = pos_feats[batch_idx][i:i + 1]
        if freeze:
            tok_embeds = tok_embeds.detach()
            mask_feat = mask_feat.detach()
            pos_feat = pos_feat.detach()
        else:
            # Match the embedding dtype (e.g. under fp16/bf16 autocast setups).
            mask_feat = mask_feat.to(tok_embeds.dtype)
            pos_feat = pos_feat.to(tok_embeds.dtype)
        cur_new_input_embeds.append(tok_embeds)
        ## mask
        cur_new_input_embeds.append(mask_feat)
        ## pos
        cur_new_input_embeds.append(pos_feat)
        if labels is not None:
            # The inserted pair (mask feat + pos feat) must not contribute to the LM loss.
            cur_labels[idx[0]:idx[0] + 2] = torch.full(
                (2,), IGNORE_INDEX, device=labels.device, dtype=labels.dtype)
        # Skip past the '<mask>' token and the position it pairs with.
        _l = idx[0] + 2
    if _l < len(cur_input_ids):
        # Remaining tokens after the last '<mask>' occurrence.
        tail_embeds = self.get_model().embed_tokens(cur_input_ids[_l:])
        cur_new_input_embeds.append(tail_embeds.detach() if freeze else tail_embeds)
if labels is not None:
    # NOTE(review): indentation was mangled in the original paste; this append is
    # placed at the per-sample level (outside the numel() guard), matching the
    # usual LLaVA-style batch loop — confirm against the enclosing method.
    cur_new_labels.append(cur_labels)
# NOTE(review): the following user report was pasted into the source file as plain
# text (it is not code); kept here as a comment so the module still parses.
# Original:
#   您好
#   请问你们在训练的时候,有没有遇到过训练卡在第一个epoch,但是GPU占用为100%的情况
#   一开始以为是服务器的问题,但只要把MASK Token部分代码删掉,就可以正常训练
# Translation:
#   "Hello — during training, have you ever seen training hang at the first epoch
#   while GPU utilization stays at 100%? We first suspected the server, but as soon
#   as the MASK-token code is removed, training runs normally."