wzzzd / pretrain_bert_with_maskLM

使用Mask LM预训练任务来预训练Bert模型。训练垂直领域语料的模型表征,提升下游任务的表现。
41 stars 11 forks source link

Pretrain_Bert_with_MaskLM

Info

使用Mask LM预训练任务来预训练Bert模型。

基于pytorch框架,训练关于垂直领域语料的预训练语言模型,目的是提升下游任务的表现。

Pretraining Task

Mask Language Model,简称Mask LM,即基于Mask机制的预训练语言模型。

同时支持 原生的MaskLM任务和Whole Words Masking任务。默认使用Whole Words Masking

MaskLM

使用来自于Bert的mask机制,即对于每一个句子中的词(token):

Whole Words Masking

与MaskLM类似,但是在mask的步骤有些少不同。

在Bert类模型中,考虑到如果单独使用整个词作为词表的话,那词表就太大了。不利于模型对同类词的不同变种的特征学习,故采用了WordPiece的方式进行分词。

Whole Words Masking的方法在于,在进行mask操作时,对象变为分词前的整个词,而非子词。

Model

使用原生的Bert模型作为基准模型。

Datasets

项目里的数据集来自wikitext,分成两个文件训练集(train.txt)和测试集(test.txt)。

数据以行为单位存储。

若想要替换成自己的数据集,可以使用自己的数据集进行替换。(注意:如果是预训练中文模型,需要修改配置文件Config.py中的self.initial_pretrain_modelself.initial_pretrain_tokenizer,将值修改成 bert-base-chinese

自己的数据集不需要做mask机制处理,代码会处理。

Training Target

本项目目的在于基于现有的预训练模型参数,如google开源的bert-base-uncasedbert-base-chinese等,在垂直领域的数据语料上,再次进行预训练任务,由此提升bert的模型表征能力,换句话说,也就是提升下游任务的表现。

Environment

项目主要使用了Huggingface的datasetstransformers模块,支持CPU、单卡单机、单机多卡三种模式。

python的版本为: 3.8

可通过以下命令安装依赖包

    pip install -r requirement.txt

主要包含的模块如下:

    numpy==1.24.1
    pandas==1.5.2
    scikit_learn==1.2.1
    torch==1.8.0
    tqdm==4.64.1
    transformers==4.26.1

Get Start

单卡模式

(1) 训练

直接运行

    python main.py

(2) 测试

修改Config.py文件中的self.mode='test',再运行

    python main.py

多卡模式(训练)

如果你足够幸运,拥有了多张GPU卡,那么恭喜你,你可以进入起飞模式。🚀🚀

(1)使用torch的nn.parallel.DistributedDataParallel模块进行多卡训练。其中config.py文件中参数如下,默认可以不用修改。

(2)运行程序启动命令

    chmod 755 run.sh
    ./run.sh

Experiment

training

使用交叉熵(cross-entropy)作为损失函数,困惑度(perplexity)和Loss作为评价指标来进行训练,训练过程如下:

test

结果保存在dataset/output/pred_data.csv,分别包含三列:

示例

src:  [CLS] art education and first professional work [SEP]
pred: [CLS] art education and first class work [SEP]
mask: [CLS] art education and first [MASK] work [SEP] [PAD] [PAD] [PAD] ...

Reference

【Bert】https://arxiv.org/pdf/1810.04805.pdf

【transformers】https://github.com/huggingface/transformers

【datasets】https://huggingface.co/docs/datasets/quicktour.html