shibing624 / MedicalGPT

MedicalGPT: Training Your Own Medical GPT Model with ChatGPT Training Pipeline. 训练医疗大模型,实现了包括增量预训练(PT)、有监督微调(SFT)、RLHF、DPO、ORPO。
Apache License 2.0
3.18k stars 483 forks source link

大量数据加载问题 #396

Closed dage0127 closed 1 month ago

dage0127 commented 1 month ago

请教一下: 训练加载的数据量比较大,300万行左右,大约需要3小时以上。加载500万行,等了差不多一晚上了,还没有完成。 最终可能需要加载超过1000万的数据进行训练,担心加载不进去。

有没有什么好的方法,一边加载一边训练,避免一次加载全量的数据,导致数据量过大,导致内存溢出。

另外,我看代码里面有流式的参数,从代码里面没有看出有太多的差别,只是少了并行处理“num_proc=data_args.preprocessing_num_workers,” 还是需要提前加载近全量数据,Tokenization,然后再训练。

with training_args.main_process_first(desc="Dataset tokenization and grouping"): if not data_args.streaming: if training_args.group_by_length: tokenized_datasets = raw_datasets.map( tokenize_wo_pad_function, batched=True, num_proc=data_args.preprocessing_num_workers, remove_columns=column_names, load_from_cache_file=not data_args.overwrite_cache, desc="Running tokenizer on dataset", ) lm_datasets = tokenized_datasets.map( group_text_function, batched=True, num_proc=data_args.preprocessing_num_workers, load_from_cache_file=not data_args.overwrite_cache, desc=f"Grouping texts in chunks of {block_size}", ) else: lm_datasets = raw_datasets.map( tokenize_function, batched=True, num_proc=data_args.preprocessing_num_workers, remove_columns=column_names, load_from_cache_file=not data_args.overwrite_cache, desc="Running tokenizer on dataset", ) else: if training_args.group_by_length: tokenized_datasets = raw_datasets.map( tokenize_wo_pad_function, batched=True, remove_columns=column_names, ) lm_datasets = tokenized_datasets.map( group_text_function, batched=True, ) else: lm_datasets = raw_datasets.map( tokenize_function, batched=True, remove_columns=column_names, )

shibing624 commented 1 month ago

用流式streaming

dage0127 commented 1 month ago

谢谢了。