Open · jouw opened this issue 11 months ago
# Tokenize every sample of `<data_dir>/<data_type>.jsonl` into a flat list of
# token ids (appended to self.data) together with the (start, end) token spans
# to exclude from the loss (appended to self.no_loss_spans).
#
# Per-line JSON fields read below (schema inferred from the keys used; confirm
# against the data files):
#   meta_instruction: str       -- encoded first; its whole span is no-loss
#   num_turns:        int-like  -- how many of chat['turn_1'..] to read
#   chat: {"turn_1": {segment_name: text, ...}, "turn_2": {...}, ...}
# Segments of a turn are concatenated in dict order. Only the
# "Tool Responses" segment (minus its format tokens) adds a no-loss span.
# Spans look half-open [start, end), judging by (0, len(instruction_ids)).
max_tokens = 2048  # per-sample token budget; turns past it are dropped whole
data_path = os.path.join(self.data_dir, f'{self.data_type}.jsonl')
# Explicit UTF-8: the platform default encoding (e.g. on Windows) can fail on
# or silently mis-decode non-ASCII chat text.
with open(data_path, 'r', encoding='utf-8') as f:
    for line in f:
        if not line.strip():
            # Tolerate blank lines (e.g. a trailing newline) instead of letting
            # json.loads raise on them.
            continue
        sample = json.loads(line)

        chat = sample['chat']
        num_turns = int(sample['num_turns'])

        meta_instruction = sample['meta_instruction']
        instruction_ids = self.tokenizer.encode(meta_instruction)
        # Explicit raise rather than assert: asserts vanish under `python -O`.
        if not isinstance(instruction_ids, list) or len(instruction_ids) == 0:
            raise ValueError(
                f'tokenizer produced no ids for meta_instruction in {data_path}'
            )

        # A shallow copy suffices (the checked list is extended in place below
        # and must not alias instruction_ids); deepcopy is unnecessary here.
        input_ids = list(instruction_ids)
        no_loss_spans = [(0, len(instruction_ids))]

        for i in range(num_turns):
            cur_turn_ids = []
            cur_no_loss_spans = []
            cur_turn = chat[f'turn_{i+1}']
            for key, value in cur_turn.items():
                cur_ids = self.tokenizer.encode(value)
                if not isinstance(cur_ids, list) or len(cur_ids) == 0:
                    raise ValueError(
                        f'tokenizer produced no ids for segment {key!r} of '
                        f'turn_{i+1} in {data_path}'
                    )

                # Offset of this segment within the sample being built.
                seg_start = len(input_ids) + len(cur_turn_ids)
                if key == 'Tool Responses':
                    # The format tokens (<|Results|>:...<eor>\n) should have
                    # losses, so only the interior of the segment is masked.
                    # NOTE(review): 5 and 2 presumably are the token counts of
                    # that prefix/suffix for this tokenizer -- confirm.
                    cur_no_loss_spans.append(
                        (seg_start + 5, seg_start + len(cur_ids) - 2)
                    )

                cur_turn_ids.extend(cur_ids)

            # Keep whole turns only: stop at the first turn that would overflow.
            if len(input_ids) + len(cur_turn_ids) > max_tokens:
                break

            input_ids.extend(cur_turn_ids)
            no_loss_spans.extend(cur_no_loss_spans)

        # No turn fit (or there were none): the sample would be only the
        # instruction, so skip it.
        if len(input_ids) == len(instruction_ids):
            continue

        # Internal invariant guaranteed by the checks above.
        assert 0 < len(input_ids) <= max_tokens

        self.data.append(input_ids)
        self.no_loss_spans.append(no_loss_spans)