OpenMOSS / MOSS

An open-source tool-augmented conversational language model from Fudan University
https://txsun1997.github.io/blogs/moss.html
Apache License 2.0

多轮对话数据处理 #349

Open 447428054 opened 1 year ago

447428054 commented 1 year ago

@linonetwo @jsl9208 @xpqiu @meta-tabchen Hello,

    with open(os.path.join(self.data_dir, f'{self.data_type}.jsonl'), 'r') as f:
        for line in f:
            sample = json.loads(line)

            chat = sample['chat']
            num_turns = int(sample['num_turns'])

            meta_instruction = sample['meta_instruction']
            instruction_ids = self.tokenizer.encode(meta_instruction)
            assert isinstance(instruction_ids, list) and len(instruction_ids) > 0

            input_ids = copy.deepcopy(instruction_ids)
            no_loss_spans = [(0, len(instruction_ids))]

            for i in range(num_turns):
                cur_turn_ids = []
                cur_no_loss_spans = []
                cur_turn = chat[f'turn_{i+1}']
                for key, value in cur_turn.items():

                    cur_ids = self.tokenizer.encode(value)

                    if key == 'Tool Responses':
                        # The format tokens (<|Results|>:...<eor>\n) should have losses.
                        cur_no_loss_spans.append((len(input_ids + cur_turn_ids) + 5, len(input_ids + cur_turn_ids + cur_ids) - 2))

                    assert isinstance(cur_ids, list) and len(cur_ids) > 0

                    cur_turn_ids.extend(cur_ids)

                if len(input_ids + cur_turn_ids) > 2048:
                    break

                input_ids.extend(cur_turn_ids)
                no_loss_spans.extend(cur_no_loss_spans)

            if len(input_ids) == len(instruction_ids):
                continue

            assert len(input_ids) > 0 and len(input_ids) <= 2048

            self.data.append(input_ids)
            self.no_loss_spans.append(no_loss_spans)
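For context, `no_loss_spans` collected this way is typically applied at loss time by setting the corresponding label positions to an ignore index (`-100` is the PyTorch/Hugging Face convention). A minimal sketch of that step, with illustrative variable names rather than the actual MOSS training code:

```python
# Build a labels sequence from input_ids and no_loss_spans.
# Positions inside any (start, end) span are set to -100, which
# cross-entropy loss implementations conventionally ignore.
IGNORE_INDEX = -100

def build_labels(input_ids, no_loss_spans):
    labels = list(input_ids)
    for start, end in no_loss_spans:
        for pos in range(start, min(end, len(labels))):
            labels[pos] = IGNORE_INDEX
    return labels

# Example: a 10-token sample whose first 4 tokens are the
# meta instruction (no loss) and the rest are the response.
input_ids = [101, 102, 103, 104, 5, 6, 7, 8, 9, 10]
labels = build_labels(input_ids, [(0, 4)])
# labels -> [-100, -100, -100, -100, 5, 6, 7, 8, 9, 10]
```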

Can packing the multi-turn conversation into a single sample like this actually speed up training? My understanding is that it reduces the number of forward passes but increases the sequence length in backpropagation, and gradient computation in the backward pass is more expensive. Could this approach end up slower than training each turn as a separate sample?
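One way to compare the two schemes is by total tokens processed, since both forward and backward cost grow with sequence length. Splitting each turn into its own sample re-encodes the shared prefix (meta instruction plus all earlier turns) every time, so the packed form usually processes fewer tokens overall. A rough illustration with made-up turn lengths:

```python
# Per-turn lengths (prompt + reply tokens) of a 4-turn chat,
# plus a shared meta instruction of 20 tokens.
meta = 20
turns = [30, 40, 25, 35]

# Packed: one sample containing the instruction and all turns.
packed_tokens = meta + sum(turns)  # 150

# Split: one sample per turn, each repeating the full prefix
# (instruction + all earlier turns) as context.
split_tokens = sum(meta + sum(turns[:i + 1]) for i in range(len(turns)))
# 50 + 90 + 115 + 150 = 405
```

This token count ignores the quadratic attention term and any per-sample overhead, so it is only a first-order comparison, but it shows why packing tends to reduce total compute rather than increase it.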

YanZiBuGuiCHunShiWan commented 11 months ago

I'm curious about this too. Also, why doesn't the causal LM mask out the loss on the human-question portions of the multi-turn conversation? Doesn't that affect the model?
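If one did want to exclude the human turns from the loss, the `no_loss_spans` mechanism shown in the issue would accommodate it: record a span for each human segment while concatenating `input_ids`, just as the code above already does for the meta instruction and tool responses. A hypothetical sketch (the role keys and toy tokenizer are stand-ins, not the actual MOSS data format):

```python
# Hypothetical: collect no-loss spans for human turns while
# concatenating a conversation, mirroring the issue's loop.
def pack_with_human_mask(turn_items, encode):
    """turn_items: list of (role, text); encode: text -> token ids."""
    input_ids, no_loss_spans = [], []
    for role, text in turn_items:
        ids = encode(text)
        if role == "Human":
            # Skip loss over the entire human utterance.
            no_loss_spans.append((len(input_ids), len(input_ids) + len(ids)))
        input_ids.extend(ids)
    return input_ids, no_loss_spans

# Toy tokenizer: one token per character.
encode = lambda s: [ord(c) for c in s]
ids, spans = pack_with_human_mask(
    [("Human", "hi"), ("MOSS", "hello")], encode
)
# spans -> [(0, 2)]: only the human tokens are excluded from loss.
```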