2020-12-22 23:56:06,255 - root - INFO - start training language model
Traceback (most recent call last):
  File "train_lm.py", line 90, in <module>
    train.train()
  File "train_lm.py", line 45, in train
    self.runner.set_datasets(train, test)
  File "/home/user/TensorflowASR/trainer/base_runners.py", line 164, in set_datasets
    self.train_datasets=self.strategy.experimental_distribute_dataset(train)
  File "/home/user/miniconda3/envs/tf/lib/python3.8/site-packages/tensorflow/python/distribute/distribute_lib.py", line 805, in experimental_distribute_dataset
    return self._extended._experimental_distribute_dataset(dataset)  # pylint: disable=protected-access
  File "/home/user/miniconda3/envs/tf/lib/python3.8/site-packages/tensorflow/python/distribute/mirrored_strategy.py", line 638, in _experimental_distribute_dataset
    return input_lib.get_distributed_dataset(
  File "/home/user/miniconda3/envs/tf/lib/python3.8/site-packages/tensorflow/python/distribute/input_lib.py", line 84, in get_distributed_dataset
    return DistributedDataset(
  File "/home/user/miniconda3/envs/tf/lib/python3.8/site-packages/tensorflow/python/distribute/input_lib.py", line 659, in __init__
    with ops.colocate_with(dataset._variant_tensor):
AttributeError: 'generator' object has no attribute '_variant_tensor'
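I am trying to train a language model with this project. After modifying `configs/lm_data.yml` and running `python train_lm.py`, training always fails; the error log is shown above. I tried TF 2.2.0 on both CPU and GPU and got the same error log.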
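The last traceback line suggests that `set_datasets` passed a plain Python generator object to `strategy.experimental_distribute_dataset`, which expects a `tf.data.Dataset` (it reads the dataset's `_variant_tensor`). Below is a minimal sketch of wrapping a generator with `tf.data.Dataset.from_generator` before distributing it. It assumes the LM data can be exposed as a generator *function* that yields fixed-length int32 (input, target) sequences. The names `lm_examples` and `to_dataset`, the sequence length, and the batch size are hypothetical and do not come from TensorflowASR.

```python
import numpy as np
import tensorflow as tf


# Hypothetical stand-in for the project's LM data loader: a generator
# function yielding (input_ids, target_ids) pairs, where the targets are
# the inputs shifted by one token.
def lm_examples(num_examples=8, seq_len=5):
    for i in range(num_examples):
        ids = np.arange(i, i + seq_len + 1, dtype=np.int32)
        yield ids[:-1], ids[1:]


def to_dataset(generator_fn, global_batch_size=4, seq_len=5):
    # from_generator takes the *callable* (the generator function), not an
    # already-created generator object. output_types/output_shapes are used
    # because output_signature is not available in TF 2.2.
    ds = tf.data.Dataset.from_generator(
        generator_fn,
        output_types=(tf.int32, tf.int32),
        output_shapes=(tf.TensorShape([seq_len]), tf.TensorShape([seq_len])),
    )
    # With a tf.distribute strategy, batch with the global batch size.
    return ds.batch(global_batch_size)


strategy = tf.distribute.MirroredStrategy()

# Passing lm_examples() (a generator object) here instead is the situation
# that produced the AttributeError in the TF 2.2.0 traceback above.
dist_ds = strategy.experimental_distribute_dataset(to_dataset(lm_examples))
```

If the project's loader only exposes a generator object, one assumption-laden workaround is to pass `lambda: make_generator()` (for whatever function creates it) to `from_generator`, so that TensorFlow can recreate the iterator on each epoch.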