CircleRadon / Osprey

[CVPR2024] The code for "Osprey: Pixel Understanding with Visual Instruction Tuning"
Apache License 2.0

About model training #24

Open chencn2020 opened 4 months ago

chencn2020 commented 4 months ago

Hello,

During training, have you ever run into the case where training gets stuck in the first epoch while GPU utilization stays at 100%?

At first I thought it was a server issue, but as soon as I delete the MASK token part of the code below, training runs normally:

# Splice the mask feature and spatial (pos) feature into the token embedding
# sequence at each '<mask>' position, and ignore those slots in the loss.
if cur_input_ids.numel() > 0:
    if getattr(self.config, 'tune_mm_mlp_adapter', False) and getattr(self.config, 'mm_use_im_start_end', False):
        mask_idx = torch.nonzero(cur_input_ids == self.tokenizer.convert_tokens_to_ids(['<mask>'])[0])
        _l = 0
        for i, idx in enumerate(mask_idx):
            cur_new_input_embeds.append(self.get_model().embed_tokens(cur_input_ids[_l:idx[0]]).detach())
            ## mask
            cur_new_input_embeds.append(mask_feats[batch_idx][i:i+1].detach())
            ## pos
            cur_new_input_embeds.append(pos_feats[batch_idx][i:i+1].detach())
            if labels is not None:
                cur_labels[idx[0]:idx[0]+2] = torch.full((2,), IGNORE_INDEX, device=labels.device, dtype=labels.dtype)
            _l = idx[0] + 2
        if _l < len(cur_input_ids):
            cur_new_input_embeds.append(self.get_model().embed_tokens(cur_input_ids[_l:]).detach())

    else:
        mask_idx = torch.nonzero(cur_input_ids == self.tokenizer.convert_tokens_to_ids(['<mask>'])[0])
        assert len(mask_idx) == len(mask_feats[batch_idx]), "mask num not equal to mask feats"

        _l = 0
        for i, idx in enumerate(mask_idx):
            cur_raw_new_input_embeds = self.get_model().embed_tokens(cur_input_ids[_l:idx[0]])
            cur_new_input_embeds.append(cur_raw_new_input_embeds)
            ## mask
            cur_new_input_embeds.append(mask_feats[batch_idx][i:i+1].to(cur_raw_new_input_embeds.dtype))
            ## pos
            cur_new_input_embeds.append(pos_feats[batch_idx][i:i+1].to(cur_raw_new_input_embeds.dtype))

            if labels is not None:
                cur_labels[idx[0]:idx[0]+2] = torch.full((2,), IGNORE_INDEX, device=labels.device, dtype=labels.dtype)

            _l = idx[0] + 2
        if _l < len(cur_input_ids):
            cur_new_input_embeds.append(self.get_model().embed_tokens(cur_input_ids[_l:]))

    if labels is not None:
        cur_new_labels.append(cur_labels)
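
One way to check whether the loop itself is the problem is to run a stripped-down version of the splice logic on CPU with dummy tensors. Below is a minimal sketch; the mask token id, hidden size, and tensor contents are made up for illustration, and only the control flow mirrors the snippet above:

import torch

IGNORE_INDEX = -100
MASK_TOKEN_ID = 32000          # placeholder id for '<mask>'; the real id comes from the tokenizer
HIDDEN = 8                     # toy hidden size

embed_tokens = torch.nn.Embedding(32001, HIDDEN)

# Toy sequence with two '<mask>' tokens; mask/pos feats have one row per mask.
cur_input_ids = torch.tensor([1, 5, MASK_TOKEN_ID, 7, 9, MASK_TOKEN_ID, 3])
cur_labels = cur_input_ids.clone()
mask_feats = torch.randn(2, HIDDEN)
pos_feats = torch.randn(2, HIDDEN)

cur_new_input_embeds = []
mask_idx = torch.nonzero(cur_input_ids == MASK_TOKEN_ID)
assert len(mask_idx) == len(mask_feats), "mask num not equal to mask feats"

_l = 0
for i, idx in enumerate(mask_idx):
    cur_new_input_embeds.append(embed_tokens(cur_input_ids[_l:idx[0]]))  # tokens before this mask
    cur_new_input_embeds.append(mask_feats[i:i+1])                       # mask feature
    cur_new_input_embeds.append(pos_feats[i:i+1])                        # spatial feature
    cur_labels[idx[0]:idx[0]+2] = IGNORE_INDEX                           # the two injected slots are not supervised
    _l = idx[0] + 2
if _l < len(cur_input_ids):
    cur_new_input_embeds.append(embed_tokens(cur_input_ids[_l:]))

print(torch.cat(cur_new_input_embeds, dim=0).shape)

If this finishes instantly, the hang is probably elsewhere (data loading, or a collective op waiting on another rank) rather than in this loop itself.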
LiWentomng commented 4 months ago

Hello @chencn2020, we have also run into a similar situation; it usually happens when the amount of training data is very large. We suspect it may be related to performance limits of the server. If you have any good ideas, feel free to discuss further or open a PR.
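
If it helps to narrow this down, one quick check is to time the data pipeline on its own, without running the model. A minimal sketch, assuming a standard PyTorch DataLoader named train_loader (the name and batch count here are assumptions about your setup):

import time

def time_dataloader(train_loader, max_batches=50):
    # Iterate the DataLoader alone to see how fast batches arrive;
    # if this also stalls, the bottleneck is data loading rather than the model.
    start = time.time()
    for step, _batch in enumerate(train_loader):
        if step % 10 == 0:
            print(f"batch {step}: {time.time() - start:.1f}s elapsed")
        if step + 1 >= max_batches:
            break
    print(f"done: {time.time() - start:.1f}s total")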