taishan1994 / train_bert_use_your_data

Continue pre-training BERT on your own data with PyTorch

Steps

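First, download the BERT model you want to start from on Hugging Face, for example chinese-bert-wwm-ext.
1. Prepare the training file.
The training file can take many forms. This guide uses a CSV file, which must contain a column named text that holds the training data.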
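As a minimal sketch of the expected input, the snippet below writes a train.csv with a single text column using Python's standard csv module. The helper name training_csv_text and the sample sentences are illustrative assumptions, not part of this repository; the first sentence is the example from step 4 with the mask filled in.

```python
import csv
import io


def training_csv_text(texts):
    """Serialize raw text samples as CSV with a single `text` column."""
    buffer = io.StringIO()
    writer = csv.writer(buffer)
    writer.writerow(["text"])        # header column required by this guide
    for sample in texts:
        writer.writerow([sample])    # csv quotes commas and newlines as needed
    return buffer.getvalue()


# Placeholder samples; replace them with your own corpus.
samples = [
    "中国保险资产管理业协会积极推进保险私募基金登记制改革落地实施",
    "second training sentence (placeholder)",
]

# newline="" keeps the line endings produced by the csv module unchanged.
with open("train.csv", "w", newline="", encoding="utf-8") as f:
    f.write(training_csv_text(samples))
```

Each row becomes one training example, so long documents may be worth splitting into paragraphs or sentences before writing them out.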
2. Define the parameters and train.
Set the relevant parameters, then train on your own data with the following command:

python run_mlm_no_trainer.py --train_file "./train.csv" --model_name_or_path "../model_hub/chinese-bert-wwm-ext" --output_dir ./tmp/ --num_train_epochs 1 --max_seq_length 256 --preprocessing_num_workers 4

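3. The trained model is saved under output_dir (./tmp/ in the command above, as the training log confirms). The saved files do not include vocab.txt, so copy vocab.txt over from the original model.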
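The copy can be done by hand or with a few lines of Python. The sketch below assumes the base model lives in ../model_hub/chinese-bert-wwm-ext and the trained model was written to ./tmp/, matching the command above; copy_vocab is a hypothetical helper, not a function from this repository.

```python
import shutil
from pathlib import Path


def copy_vocab(base_model_dir, trained_model_dir):
    """Copy vocab.txt from the original model into the trained model directory."""
    src = Path(base_model_dir) / "vocab.txt"
    dst_dir = Path(trained_model_dir)
    if not src.is_file():
        raise FileNotFoundError(f"vocab.txt not found in {base_model_dir}")
    dst_dir.mkdir(parents=True, exist_ok=True)   # no-op if it already exists
    dst = dst_dir / "vocab.txt"
    shutil.copyfile(src, dst)
    return dst
```

For the paths used in this guide, the call would be copy_vocab("../model_hub/chinese-bert-wwm-ext", "./tmp/").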
4. Predict with your trained model.
In test_model.py, change the model path and the input text, for example:
Input: text = '中国保险资产管理业协会积极推进保险私[MASK]基金登记制改革落地实施'
Output: 募
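The contents of test_model.py are not reproduced here, so the following is only a sketch of how such a prediction could be made with the transformers library. The function names are hypothetical, and the code assumes the trained model and the copied vocab.txt both sit in one directory.

```python
def find_mask_positions(input_ids, mask_token_id):
    """Return the indices in input_ids that hold the [MASK] token id."""
    return [i for i, tok in enumerate(input_ids) if tok == mask_token_id]


def predict_masked(text, model_dir):
    """Return the most likely token for every [MASK] in text."""
    # Imported here so the helper above works without torch installed.
    import torch
    from transformers import BertForMaskedLM, BertTokenizer

    tokenizer = BertTokenizer.from_pretrained(model_dir)
    model = BertForMaskedLM.from_pretrained(model_dir)
    model.eval()

    inputs = tokenizer(text, return_tensors="pt")
    with torch.no_grad():
        logits = model(**inputs).logits          # (1, seq_len, vocab_size)

    ids = inputs["input_ids"][0].tolist()
    predictions = []
    for pos in find_mask_positions(ids, tokenizer.mask_token_id):
        best_id = int(logits[0, pos].argmax())   # highest-scoring vocab entry
        predictions.append(tokenizer.convert_ids_to_tokens(best_id))
    return predictions
```

Calling predict_masked(text, "./tmp/") with the input above returns a list with one predicted token per [MASK].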
5. Output produced during training:

Model config BertConfig {
  "architectures": [
    "BertForMaskedLM"
  ],
  "attention_probs_dropout_prob": 0.1,
  "classifier_dropout": null,
  "directionality": "bidi",
  "hidden_act": "gelu",
  "hidden_dropout_prob": 0.1,
  "hidden_size": 768,
  "initializer_range": 0.02,
  "intermediate_size": 3072,
  "layer_norm_eps": 1e-12,
  "max_position_embeddings": 512,
  "model_type": "bert",
  "num_attention_heads": 12,
  "num_hidden_layers": 12,
  "output_past": true,
  "pad_token_id": 0,
  "pooler_fc_size": 768,
  "pooler_num_attention_heads": 12,
  "pooler_num_fc_layers": 3,
  "pooler_size_per_head": 128,
  "pooler_type": "first_token_transform",
  "position_embedding_type": "absolute",
  "transformers_version": "4.11.2",
  "type_vocab_size": 2,
  "use_cache": true,
  "vocab_size": 21128
}
10/01/2021 10:16:01 - INFO - __main__ - Sample 11195 of the training set: {'attention_mask': [1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1], 'input_ids': [8024, 2779, 5635, 8119, 2399, 510, 8112, 2399, 1350, 8109, 2399, 8110, 3299, 8176, 3189, 1146, 1166, 3221, 9246, 8157, 110, 510, 9910, 8129, 119, 123, 110, 510, 11863, 119, 126, 110, 8039, 3291, 6375, 782, 904, 4680, 4638, 3221, 8024, 2779, 5635, 8271, 2399, 127, 3299, 8114, 3189, 8024, 6821, 671, 3144, 2099, 2347, 678, 7360, 5635, 8252, 119, 127, 110, 511, 6566, 965, 4999, 7360, 2533, 1963, 3634, 100, 2571, 100, 1469, 100, 1114, 100, 8024, 738, 2471, 6629, 2356, 1767, 4638, 7770, 2428, 1068, 3800, 511, 704, 3448, 3175, 7481, 1762, 2875, 5500, 741, 704, 5314, 1139, 4638, 6237, 7025, 3221, 8024, 6814, 1343, 1126, 2399, 1112, 6566, 965, 4372, 6772, 7770, 3221, 1728, 711, 712, 6206, 2418, 2190, 689, 1218, 1872, 7270, 1071, 6084, 6598, 7444, 3724, 679, 3171, 1217, 2487, 809, 3119, 6579, 7583, 1912, 1765, 1779, 8024, 1398, 3198, 4685, 1068, 4289, 689, 7555, 4680, 2213, 3313, 2458, 1993, 7564, 1545, 1350, 772, 4495, 4385, 7032, 3837, 1057, 511, 4764, 3309, 1079, 1112, 6566, 965, 4372, 678, 7360, 5635, 8252, 119, 127, 110, 8024, 3221, 2600, 3326, 4660, 2810, 1920, 1469, 924, 4522, 4659, 1164, 511, 2940, 1368, 
6413, 6432, 8024, 704, 3448, 3418, 2945, 6121, 689, 2501, 1232, 6822, 6121, 749, 712, 1220, 7360, 6566, 965, 8024, 5445, 684, 2768, 5327, 3227, 5865, 511, 704, 3448, 4955, 4994, 3221, 1963, 862, 1872, 3119, 1121, 3118, 4638, 8043, 122, 510, 1872, 3119, 8038, 2875, 5500, 741, 3227, 4850, 8024, 704, 3448, 2971, 5500, 8119, 2399, 5635, 8109, 2399, 3119, 4660], 'special_tokens_mask': [0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0], 'token_type_ids': [0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0]}.
10/01/2021 10:16:01 - INFO - __main__ - Sample 20479 of the training set: {'attention_mask': [1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1], 'input_ids': [782, 2798, 1824, 1075, 511, 3315, 3613, 1920, 6612, 809, 107, 1916, 679, 1398, 8024, 3341, 2773, 8013, 107, 711, 712, 7579, 8024, 3192, 1762, 7961, 1225, 1158, 2692, 1158, 3173, 8024, 711, 2408, 1920, 831, 4899, 2110, 2094, 2990, 897, 2245, 4850, 2798, 1290, 510, 683, 689, 769, 3837, 1350, 2110, 739, 4638, 2398, 1378, 511, 5296, 678, 2146, 6382, 1350, 2590, 775, 833, 2199, 6818, 6655, 4895, 792, 5305, 3315, 3613, 1920, 6612, 778, 4157, 8024, 2400, 711, 2110, 2094, 5031, 4542, 6237, 2663, 8024, 6858, 6814, 1649, 2161, 4638, 4385, 1767, 769, 3837, 680, 757, 1220, 8024, 2372, 5314, 2110, 2094, 3291, 1914, 1423, 1355, 511, 3680, 702, 1814, 2356, 5296, 678, 3833, 1220, 1146, 2146, 6382, 833, 1469, 2590, 775, 833, 697, 1767, 6822, 6121, 8024, 2146, 6382, 833, 712, 6206, 1259, 2886, 8038, 792, 5305, 1920, 6612, 8024, 1423, 1220, 811, 2466, 8024, 757, 1220, 5031, 4542, 5023, 3833, 1220, 511, 2590, 775, 833, 1156, 6913, 6435, 1649, 2161, 6822, 6121, 757, 1220, 1146, 775, 8024, 1649, 2161, 2199, 6913, 6435, 6121, 689, 1920, 1476, 8038, 4289, 4130, 4906, 2825, 7674, 2375, 1158, 2692, 2135, 6948, 2361, 756, 1044, 
4495, 8024, 3918, 5449, 2408, 1440, 5852, 7218, 680, 3144, 2099, 1158, 2692, 1282, 676, 2399, 8039, 704, 1744, 2825, 3318, 1158, 689, 1291, 833, 1352, 1158, 683, 2157, 1999, 1447, 833, 1199, 712, 818, 1999, 1447, 6145, 4661, 1044, 4495, 8024, 683, 689, 794, 752, 1158, 689, 3302, 1218, 1469, 821, 689, 2118, 1265, 8039, 1920, 6125, 5381, 8371, 4374, 4899], 'special_tokens_mask': [0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0], 'token_type_ids': [0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0]}.
10/01/2021 10:16:01 - INFO - __main__ - Sample 50945 of the training set: {'attention_mask': [1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1], 'input_ids': [4850, 8038, 100, 2398, 2128, 1057, 5500, 1400, 8024, 2372, 3341, 3173, 4638, 5052, 4415, 1730, 7339, 833, 1762, 791, 2399, 2458, 2245, 3173, 4638, 689, 1218, 8024, 1728, 711, 1290, 1909, 3315, 6716, 2218, 3221, 7028, 6598, 772, 4289, 689, 6772, 1914, 8024, 2792, 809, 1184, 3309, 1762, 3173, 689, 1218, 2458, 2245, 3175, 7481, 8024, 712, 6206, 809, 6768, 6598, 772, 5052, 4415, 6783, 1139, 711, 712, 511, 100, 1343, 2399, 130, 3299, 819, 8024, 1290, 1909, 2401, 4886, 680, 2398, 2128, 7415, 1730, 5041, 5392, 2773, 4526, 1394, 868, 1291, 6379, 8024, 1352, 3175, 2199, 1762, 7270, 4909, 1062, 2171, 510, 2434, 1075, 3302, 1218, 510, 3749, 6756, 3302, 1218, 772, 689, 7415, 5408, 510, 3255, 2716, 1814, 2356, 5023, 4052, 1762, 1355, 2245, 3175, 1403, 677, 3918, 1057, 2968, 6374, 1469, 4777, 4955, 2458, 2245, 1392, 5102, 1798, 1394, 868, 511, 966, 2533, 3800, 2692, 4638, 3221, 8024, 2398, 2128, 2824, 6437, 679, 712, 1220, 6450, 3724, 2190, 677, 2356, 1062, 1385, 4638, 2971, 1169, 3326, 8024, 2792, 809, 1993, 5303, 1290, 1909, 2971, 5500, 5018, 671, 6121, 1220, 782, 1762, 677, 2356, 1062, 1385, 5500, 3326, 3683, 891, 
7770, 6809, 8216, 110, 809, 677, 511, 9160, 2399, 123, 3299, 8024, 1290, 3883, 5390, 1765, 1426, 1403, 691, 4638, 1217, 4673, 8024, 1059, 7481, 2972, 6822, 1290, 1909, 2401, 4886, 3173, 689, 1218, 8024, 6821, 738, 3221, 1426, 1403, 691, 3078, 7270, 4638, 7566, 1818, 8024, 1555, 1215, 7555, 4680, 4638, 2458, 1355, 680, 6817, 5852, 511, 1290, 1909, 2401, 4886, 6392], 'special_tokens_mask': [0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0], 'token_type_ids': [0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0]}.
10/01/2021 10:16:10 - INFO - __main__ - ***** Running training *****
10/01/2021 10:16:10 - INFO - __main__ -   Num examples = 86793
10/01/2021 10:16:10 - INFO - __main__ -   Num Epochs = 1
10/01/2021 10:16:10 - INFO - __main__ -   Instantaneous batch size per device = 8
10/01/2021 10:16:10 - INFO - __main__ -   Total train batch size (w. parallel, distributed & accumulation) = 8
10/01/2021 10:16:10 - INFO - __main__ -   Gradient Accumulation steps = 1
10/01/2021 10:16:10 - INFO - __main__ -   Total optimization steps = 10850
100% 10850/10850 [2:38:29<00:00,  1.48it/s]10/01/2021 12:57:28 - INFO - __main__ - epoch 0: perplexity: 2.318695606245697
Configuration saved in ./tmp/config.json
Model weights saved in ./tmp/pytorch_model.bin
100% 10850/10850 [2:41:19<00:00,  1.12it/s]
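The numbers in this log can be sanity-checked with a little arithmetic. The sketch below assumes a single device (the total batch size equals the per-device batch size) and assumes the script reports perplexity as the exponential of the mean evaluation loss, as the Hugging Face example script of the same name does.

```python
import math

num_examples = 86793    # "Num examples"
per_device_batch = 8    # "Instantaneous batch size per device"
grad_accum_steps = 1    # "Gradient Accumulation steps"
num_epochs = 1          # "Num Epochs"

# Batches per epoch, rounded up so the final partial batch is counted.
batches_per_epoch = math.ceil(num_examples / per_device_batch)
updates_per_epoch = math.ceil(batches_per_epoch / grad_accum_steps)
total_steps = updates_per_epoch * num_epochs
print(total_steps)      # 10850, matching "Total optimization steps"

# Under the stated assumption, perplexity = exp(mean loss),
# so the mean loss can be recovered with a logarithm.
perplexity = 2.318695606245697
mean_loss = math.log(perplexity)
print(round(mean_loss, 2))   # about 0.84
```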

Note: keep a record of the max_seq_length value you set for training (256 in the command above).