ZhuiyiTechnology / simbert

a bert for retrieval and generation
Apache License 2.0
840 stars 152 forks source link

生成句子报错,维度不一致。 #24

Closed wjx-git closed 2 years ago

wjx-git commented 2 years ago

使用 simbert.py 训练得到模型,使用 gen_synonyms_test.py 测试时报错: tensorflow.python.framework.errors_impl.InvalidArgumentError: Input to reshape is a tensor with 13584 values, but the requested shape has 13685

wjx-git commented 2 years ago

解决方法: 训练模型时将 keep_tokens 注释即可。 `bert = build_transformer_model( config_path, checkpoint_path, with_pool='linear', application='unilm',

keep_tokens=keep_tokens, # 只保留keep_tokens中的字,精简原字表

    return_keras_model=False,
)`