Z-yq / TensorflowASR

A project dedicated to bringing CPU/on-device models close to GPU-model performance, with a real-time factor (RTF) below 0.1 on CPU.
Apache License 2.0

Problem training a language model #22

Closed: zhaoyukoon closed this issue 3 years ago

zhaoyukoon commented 3 years ago

I tried to train a language model with this project. I modified configs/lm_data.yml as follows:

train_list: './common.all.1w'
eval_list: './common.all.1w'
...

bert:
  config_json: './LMmodel/bert/bert_config.json'
  bert_ckpt: './LMmodel/bert/bert_model.ckpt'
  bert_vocab: './LMmodel/bert/vocab.txt'

After that, running python train_lm.py always fails. Here is the error log:

2020-12-22 23:56:06,255 - root - INFO - start training language model
Traceback (most recent call last):
  File "train_lm.py", line 90, in <module>
    train.train()
  File "train_lm.py", line 45, in train
    self.runner.set_datasets(train, test)
  File "/home/user/TensorflowASR/trainer/base_runners.py", line 164, in set_datasets
    self.train_datasets=self.strategy.experimental_distribute_dataset(train)
  File "/home/user/miniconda3/envs/tf/lib/python3.8/site-packages/tensorflow/python/distribute/distribute_lib.py", line 805, in experimental_distribute_dataset
    return self._extended._experimental_distribute_dataset(dataset)  # pylint: disable=protected-access
  File "/home/user/miniconda3/envs/tf/lib/python3.8/site-packages/tensorflow/python/distribute/mirrored_strategy.py", line 638, in _experimental_distribute_dataset
    return input_lib.get_distributed_dataset(
  File "/home/user/miniconda3/envs/tf/lib/python3.8/site-packages/tensorflow/python/distribute/input_lib.py", line 84, in get_distributed_dataset
    return DistributedDataset(
  File "/home/user/miniconda3/envs/tf/lib/python3.8/site-packages/tensorflow/python/distribute/input_lib.py", line 659, in __init__
    with ops.colocate_with(dataset._variant_tensor):
AttributeError: 'generator' object has no attribute '_variant_tensor'

I tried both the CPU and GPU builds of TensorFlow 2.2.0 and got the same error log.
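The traceback shows `strategy.experimental_distribute_dataset` receiving a plain Python `generator` object, while `tf.distribute` strategies expect a `tf.data.Dataset` (`_variant_tensor` is an attribute of datasets, which a generator lacks). The thread does not show how the maintainer actually fixed it. The sketch below illustrates one common way to avoid this class of error, assuming the training samples come from a Python generator: wrap the generator *function* with `tf.data.Dataset.from_generator` before distributing. `token_batches` and `make_dataset` are hypothetical names, not TensorflowASR code, and the fixed 16-token samples are made up for illustration.

```python
import numpy as np
import tensorflow as tf


def token_batches():
    """Hypothetical stand-in for an LM data generator.

    Yields (input_ids, labels) pairs of int32 token sequences of length 16.
    """
    for i in range(8):
        ids = np.arange(i, i + 16, dtype=np.int32)
        yield ids, ids


def make_dataset(gen_fn, batch_size=4):
    # from_generator takes the generator *function* (a callable), not a
    # generator object. output_types/output_shapes work on TF 2.2; newer
    # releases also accept output_signature instead.
    ds = tf.data.Dataset.from_generator(
        gen_fn,
        output_types=(tf.int32, tf.int32),
        output_shapes=(tf.TensorShape([16]), tf.TensorShape([16])),
    )
    return ds.batch(batch_size)


strategy = tf.distribute.MirroredStrategy()
train_ds = make_dataset(token_batches)

# Passing a tf.data.Dataset (not token_batches(), a generator object) is what
# experimental_distribute_dataset expects, so it no longer looks for a
# missing _variant_tensor attribute.
dist_train = strategy.experimental_distribute_dataset(train_ds)
```

With 8 samples and a global batch size of 4, iterating `dist_train` yields 2 steps. On a multi-GPU machine, `MirroredStrategy` splits each global batch across replicas.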

Z-yq commented 3 years ago

Fixed.

zhaoyukoon commented 3 years ago

I verified it, and the problem is solved. I ran into another problem, so I will open a new issue.