FudanDISC / DISC-MedLLM

Repository of DISC-MedLLM, a comprehensive solution that leverages Large Language Models (LLMs) to provide accurate and truthful medical responses in end-to-end conversational healthcare services.
Apache License 2.0
485 stars, 44 forks

Runtime error #4

Open loki1017 opened 1 year ago

loki1017 commented 1 year ago

Many thanks for your generous contribution. When I ran QLoRA fine-tuning on Baichuan-Chat, I ran into a KeyError, as follows:

╭─────────────────────────────── Traceback (most recent call last) ────────────────────────────────╮
│ /Work/disc/train/train.py:185 in <module>                                        │
│                                                                                                  │
│   182                                                                                            │
│   183                                                                                            │
│   184 if __name__ == "__main__":                                                                 │
│ ❱ 185 │   main()                                                                                 │
│   186                                                                                            │
│                                                                                                  │
│ /Work/disc/train/train.py:173 in main                                            │
│                                                                                                  │
│   170 │   trainer = init_components(args, training_args)                                         │
│   171 │   # start training                                                                       │
│   172 │   logger.info("*** starting training ***")                                               │
│ ❱ 173 │   train_result = trainer.train(resume_from_checkpoint=training_args.resume_from_checkp   │
│   174 │   # save the best checkpoint                                                             │
│   175 │   final_save_path = join(training_args.output_dir, 'final')                              │
│   176 │   trainer.save_model(final_save_path)  # Saves the tokenizer too                         │
│                                                                                                  │
│ /home/anaconda3/envs/lib/python3.10/site-packages/transformers/trainer.py │
│ :1645 in train                                                                                   │
│                                                                                                  │
│   1642 │   │   inner_training_loop = find_executable_batch_size(                                 │
│   1643 │   │   │   self._inner_training_loop, self._train_batch_size, args.auto_find_batch_size  │
│   1644 │   │   )                                                                                 │
│ ❱ 1645 │   │   return inner_training_loop(                                                       │
│   1646 │   │   │   args=args,                                                                    │
│   1647 │   │   │   resume_from_checkpoint=resume_from_checkpoint,                                │
│   1648 │   │   │   trial=trial,                                                                  │
│                                                                                                  │
│ /home/anaconda3/envs/lib/python3.10/site-packages/transformers/trainer.py │
│ :1916 in _inner_training_loop                                                                    │
│                                                                                                  │
│   1913 │   │   │   │   rng_to_sync = True                                                        │
│   1914 │   │   │                                                                                 │
│   1915 │   │   │   step = -1                                                                     │
│ ❱ 1916 │   │   │   for step, inputs in enumerate(epoch_iterator):                                │
│   1917 │   │   │   │   total_batched_samples += 1                                                │
│   1918 │   │   │   │   if rng_to_sync:                                                           │
│   1919 │   │   │   │   │   self._load_rng_state(resume_from_checkpoint)                          │
│                                                                                                  │
│ /home/anaconda3/envs/lib/python3.10/site-packages/torch/utils/data/datalo │
│ ader.py:633 in __next__                                                                          │
│                                                                                                  │
│    630 │   │   │   if self._sampler_iter is None:                                                │
│    631 │   │   │   │   # TODO(https://github.com/pytorch/pytorch/issues/76750)                   │
│    632 │   │   │   │   self._reset()  # type: ignore[call-arg]                                   │
│ ❱  633 │   │   │   data = self._next_data()                                                      │
│    634 │   │   │   self._num_yielded += 1                                                        │
│    635 │   │   │   if self._dataset_kind == _DatasetKind.Iterable and \                          │
│    636 │   │   │   │   │   self._IterableDataset_len_called is not None and \                    │
│                                                                                                  │
│ /home/anaconda3/envs/lib/python3.10/site-packages/torch/utils/data/datalo │
│ ader.py:1345 in _next_data                                                                       │
│                                                                                                  │
│   1342 │   │   │   │   self._task_info[idx] += (data,)                                           │
│   1343 │   │   │   else:                                                                         │
│   1344 │   │   │   │   del self._task_info[idx]                                                  │
│ ❱ 1345 │   │   │   │   return self._process_data(data)                                           │
│   1346 │                                                                                         │
│   1347 │   def _try_put_index(self):                                                             │
│   1348 │   │   assert self._tasks_outstanding < self._prefetch_factor * self._num_workers        │
│                                                                                                  │
│ /home/anaconda3/envs/lib/python3.10/site-packages/torch/utils/data/datalo │
│ ader.py:1371 in _process_data                                                                    │
│                                                                                                  │
│   1368 │   │   self._rcvd_idx += 1                                                               │
│   1369 │   │   self._try_put_index()                                                             │
│   1370 │   │   if isinstance(data, ExceptionWrapper):                                            │
│ ❱ 1371 │   │   │   data.reraise()                                                                │
│   1372 │   │   return data                                                                       │
│   1373 │                                                                                         │
│   1374 │   def _mark_worker_as_unavailable(self, worker_id, shutdown=False):                     │
│                                                                                                  │
│ /home/anaconda3/envs/lib/python3.10/site-packages/torch/_utils.py:644 in  │
│ reraise                                                                                          │
│                                                                                                  │
│   641 │   │   │   # If the exception takes multiple arguments, don't try to                      │
│   642 │   │   │   # instantiate since we don't know how to                                       │
│   643 │   │   │   raise RuntimeError(msg) from None                                              │
│ ❱ 644 │   │   raise exception                                                                    │
│   645                                                                                            │
│   646                                                                                            │
│   647 def _get_available_device_type():                                                          │
╰──────────────────────────────────────────────────────────────────────────────────────────────────╯
KeyError: Caught KeyError in DataLoader worker process 0.
Original Traceback (most recent call last):
  File "/home/anaconda3/envs/lib/python3.10/site-packages/torch/utils/data/_utils/worker.py", line 308, in _worker_loop
    data = fetcher.fetch(index)
  File "/home/anaconda3/envs/lib/python3.10/site-packages/torch/utils/data/_utils/fetch.py", line 54, in fetch
    return self.collate_fn(data)
  File "/home/anaconda3/envs/lib/python3.10/site-packages/transformers/trainer_utils.py", line 706, in __call__
    return self.data_collator(features)
  File "/Work/disc/train/component/collator.py", line 23, in __call__
    target_mask = x['target_mask']
KeyError: 'target_mask'

Launch command and arguments:

gpu_vis=4,5
MASTER_PORT=1942
deepspeed  --include localhost:$gpu_vis --master_port $MASTER_PORT disc/train/train.py \
    --deepspeed disc/train/train_args/ds_z3_config.json \
    --output_dir disc/out \
    --model_name_or_path pre_model/Baichuan-13B-Chat-v2 \
    --train_file disc/train/data/DISC-Med-SFT_released.jsonl \
    --overwrite_cache \
    --overwrite_output_dir \
    --num_train_epochs 1.0 \
    --per_device_train_batch_size 4 \
    --gradient_accumulation_steps 3 \
    --learning_rate 1e-5 \
    --max_seq_length 1200 \
    --logging_steps 50 \
    --save_steps 2000 \
    --save_total_limit 3 \
    --lr_scheduler_type cosine \
    --warmup_steps 800 \
    --gradient_checkpointing false \
    --disable_tqdm false \
    --optim adamw_hf \
    --seed 42 \
    --fp16 false \
    --bf16 true \
    --report_to tensorboard \
    --dataloader_num_workers 5  \
    --save_strategy steps \
    --weight_decay 0 \
    --max_grad_norm 1.0 \
    --quantization_bit 4

I found that the __call__() method in trainer_utils.py of the transformers package strips the target_mask field from each feature, as shown below:

    def _remove_columns(self, feature: dict) -> dict:
        if not isinstance(feature, dict):
            return feature
        if not self.message_logged and self.logger and self.model_name:
            ignored_columns = list(set(feature.keys()) - set(self.signature_columns))
            if len(ignored_columns) > 0:
                dset_description = "" if self.description is None else f"in the {self.description} set"
                self.logger.info(
                    f"The following columns {dset_description} don't have a corresponding argument in "
                    f"`{self.model_name}.forward` and have been ignored: {', '.join(ignored_columns)}."
                    f" If {', '.join(ignored_columns)} are not expected by `{self.model_name}.forward`, "
                    " you can safely ignore this message."
                )
                self.message_logged = True
        return {k: v for k, v in feature.items() if k in self.signature_columns}

    def __call__(self, features: List[dict]):
        features = [self._remove_columns(feature) for feature in features]
        return self.data_collator(features)
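To make the failure mode concrete, here is a simplified, standalone re-implementation of that filtering (not the library's actual code; it does not import transformers). The forward function is a hypothetical stand-in for a model whose signature accepts input_ids, attention_mask and labels but not target_mask, and collate is a hypothetical collator that, like component/collator.py, reads target_mask. As far as I know, the transformers Trainer only wraps the data collator this way when TrainingArguments.remove_unused_columns is True (the default), so passing --remove_unused_columns false is one possible way to keep target_mask in each feature. That is an assumption; the thread does not confirm it is sufficient for this repository.

```python
import inspect
from typing import Dict, List


def forward(input_ids=None, attention_mask=None, labels=None):
    """Hypothetical stand-in for model.forward(); note there is no target_mask parameter."""
    return None


def signature_columns(fn) -> List[str]:
    # Parameter names of fn plus label column names, approximating what the Trainer keeps
    return list(inspect.signature(fn).parameters.keys()) + ["label", "label_ids"]


def remove_columns(feature: Dict, allowed: List[str]) -> Dict:
    # Same rule as _remove_columns: any key not in the allowed list is dropped
    return {k: v for k, v in feature.items() if k in allowed}


def collate(features: List[Dict]) -> Dict[str, List]:
    # Hypothetical collator that, like component/collator.py, requires target_mask
    return {
        "input_ids": [f["input_ids"] for f in features],
        "target_mask": [f["target_mask"] for f in features],
    }


def run(remove_unused_columns: bool) -> Dict[str, List]:
    features = [{"input_ids": [1, 2, 3], "attention_mask": [1, 1, 1], "target_mask": [0, 1, 1]}]
    if remove_unused_columns:
        allowed = signature_columns(forward)
        features = [remove_columns(f, allowed) for f in features]
    return collate(features)


try:
    run(remove_unused_columns=True)
except KeyError as e:
    print("KeyError:", e)  # KeyError: 'target_mask'
print(run(remove_unused_columns=False)["target_mask"])  # [[0, 1, 1]]
```

With filtering enabled the collator fails exactly as in the traceback; with it disabled the extra key reaches the collator unchanged.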

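I modified the source file and commented out the line features = [self._remove_columns(feature) for feature in features], but then other errors occurred. So I would like to know which exact versions of transformers and transformers_stream_generator you are using, or whether I have missed something in the code logic. Thanks a lot! (My versions: transformers==4.30.1, transformers-stream-generator==0.0.4)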

F1mc commented 1 year ago
transformers==4.31.0
transformers-stream-generator==0.0.4

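Our training code is adapted from the Firefly repository, which uses a different masking mechanism for the loss. Normally, a three-turn conversation would be split into three training samples, each using only the current reply as the target for the loss. Firefly instead feeds the whole conversation to the model at once and uses all three replies as targets simultaneously, which is more efficient.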
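The two masking schemes described above can be sketched in plain Python. The token ids are made-up integers standing in for tokenized text, special tokens and truncation are ignored, and the helper names (build_firefly_sample, build_per_turn_samples, mask_to_labels) are hypothetical; they are not the Firefly or DISC-MedLLM functions. Turning target_mask into labels with -100 at ignored positions follows the usual convention of PyTorch's CrossEntropyLoss(ignore_index=-100); whether Firefly applies the mask in exactly this way is an assumption.

```python
from typing import Dict, List, Tuple

IGNORE_INDEX = -100  # default ignore_index of torch.nn.CrossEntropyLoss

Turn = Tuple[List[int], List[int]]  # (user tokens, assistant tokens)


def build_firefly_sample(turns: List[Turn]) -> Dict[str, List[int]]:
    """Whole conversation in one sample; every assistant reply is a target."""
    input_ids: List[int] = []
    target_mask: List[int] = []
    for user, assistant in turns:
        input_ids += user + assistant
        target_mask += [0] * len(user) + [1] * len(assistant)
    return {"input_ids": input_ids, "target_mask": target_mask}


def build_per_turn_samples(turns: List[Turn]) -> List[Dict[str, List[int]]]:
    """One sample per turn; only the current (last) reply is a target."""
    samples = []
    for i in range(len(turns)):
        history = build_firefly_sample(turns[: i + 1])
        # Keep earlier turns as context, but mask out every reply except the last one
        n_context = len(history["input_ids"]) - len(turns[i][1])
        mask = [0] * n_context + [1] * len(turns[i][1])
        samples.append({"input_ids": history["input_ids"], "target_mask": mask})
    return samples


def mask_to_labels(sample: Dict[str, List[int]]) -> List[int]:
    # Positions where target_mask == 0 are ignored by the loss
    return [t if m == 1 else IGNORE_INDEX for t, m in zip(sample["input_ids"], sample["target_mask"])]


conversation = [([11, 12], [21]), ([13], [22, 23]), ([14], [24])]

firefly = build_firefly_sample(conversation)
print(firefly["target_mask"])   # [0, 0, 1, 0, 1, 1, 0, 1]
print(mask_to_labels(firefly))  # [-100, -100, 21, -100, 22, 23, -100, 24]
per_turn = build_per_turn_samples(conversation)
print(len(per_turn), [sum(s["target_mask"]) for s in per_turn])  # 3 [1, 2, 1]
```

Both schemes supervise the same four reply tokens, but the per-turn split processes 3 + 6 + 8 = 17 tokens across three sequences versus 8 tokens in a single sequence, which is where the efficiency gain comes from.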

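I suspect your problem comes from this special handling. For your situation, I suggest starting from the QLoRA script in the Firefly repository and debugging from there, which may solve your problem faster. Our changes are concentrated in component/dataset.py, plus some adjustments for compatibility with package versions such as transformers, so they should not get in the way of combining our code with the original repository's QLoRA fine-tuning script.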

loki1017 commented 1 year ago

OK, thank you very much. I'll try making the changes.