OpenMOSS / MOSS

An open-source tool-augmented conversational language model from Fudan University
https://txsun1997.github.io/blogs/moss.html
Apache License 2.0
11.89k stars 1.15k forks source link

finetune的时候为何没有把<|Human|>的loss给mask掉? #359

Open jouw opened 11 months ago

jouw commented 11 months ago
        # Build the training samples from `<data_dir>/<data_type>.jsonl`, one JSON
        # object per line. For every sample this produces:
        #   * a flat list of token ids (meta instruction followed by as many whole
        #     turns as fit in 2048 tokens), appended to `self.data`;
        #   * a list of (start, end) index pairs into that list, appended to
        #     `self.no_loss_spans`. Judging by the name these are the ranges
        #     excluded from the loss -- confirm against the collate/loss code.
        #
        # Only two kinds of span are ever recorded: the meta instruction and the
        # interior of a 'Tool Responses' field. Every other field of a turn adds
        # no span at all.
        #
        # The encoding is given explicitly so decoding does not depend on the
        # platform locale.
        with open(os.path.join(self.data_dir, f'{self.data_type}.jsonl'), 'r', encoding='utf-8') as f:
            for line in f:
                # Tolerate blank lines (e.g. a trailing newline at end of file).
                if not line.strip():
                    continue

                sample = json.loads(line)

                chat = sample['chat']
                num_turns = int(sample['num_turns'])

                meta_instruction = sample['meta_instruction']
                instruction_ids = self.tokenizer.encode(meta_instruction)
                assert isinstance(instruction_ids, list) and len(instruction_ids) > 0

                # Work on a copy so `instruction_ids` keeps its original length
                # for the "nothing was added" check further down.
                input_ids = list(instruction_ids)
                # The whole meta instruction is the first recorded span.
                no_loss_spans = [(0, len(instruction_ids))]

                for i in range(num_turns):
                    # Tokens and spans of the current turn are staged separately
                    # and committed only if the complete turn fits.
                    cur_turn_ids = []
                    cur_no_loss_spans = []
                    cur_turn = chat[f'turn_{i+1}']
                    for key, value in cur_turn.items():

                        cur_ids = self.tokenizer.encode(value)
                        # Validate before the ids are used in any arithmetic.
                        assert isinstance(cur_ids, list) and len(cur_ids) > 0

                        if key == 'Tool Responses':
                            # The format tokens (<|Results|>:...<eor>\n) should have losses.
                            # Absolute position at which this field starts in the
                            # final sequence; computed from lengths rather than by
                            # concatenating the lists.
                            field_start = len(input_ids) + len(cur_turn_ids)
                            # Skip 5 leading and 2 trailing tokens of the field --
                            # presumably the `<|Results|>:` prefix and `<eor>\n`
                            # suffix; TODO confirm these counts for the tokenizer.
                            cur_no_loss_spans.append((field_start + 5, field_start + len(cur_ids) - 2))

                        cur_turn_ids.extend(cur_ids)

                    # Stop at the first turn that would push the sample past 2048
                    # tokens; that turn and all later ones are dropped.
                    if len(input_ids) + len(cur_turn_ids) > 2048:
                        break

                    input_ids.extend(cur_turn_ids)
                    no_loss_spans.extend(cur_no_loss_spans)

                # No turn was added (zero turns, or the first one already
                # overflowed): skip the sample.
                if len(input_ids) == len(instruction_ids):
                    continue

                assert len(input_ids) > 0 and len(input_ids) <= 2048

                self.data.append(input_ids)
                self.no_loss_spans.append(no_loss_spans)