valuesimplex / FinBERT

Apache License 2.0

Bug when loading the model with PyTorch; TF version already changed to 1.14 #6

Open Hyacintheater opened 3 years ago

Hyacintheater commented 3 years ago

model = BertForSequenceClassification.from_pretrained('/content/sample_data/FinBERT_pytorch/bert_config.json',from_tf = True)

Here is the error:

AttributeError Traceback (most recent call last)
in ()
----> 1 model = BertForSequenceClassification.from_pretrained('/content/sample_data/FinBERT_pytorch/bert_config.json',from_tf = True)

/usr/local/lib/python3.6/dist-packages/transformers/modeling_utils.py in from_pretrained(cls, pretrained_model_name_or_path, *model_args, **kwargs)
    970 from .modeling_tf_pytorch_utils import load_tf2_checkpoint_in_pytorch_model
    971
--> 972 model = load_tf2_checkpoint_in_pytorch_model(model, resolved_archive_file, allow_missing_keys=True)
    973 except ImportError:
    974 logger.error(

/usr/local/lib/python3.6/dist-packages/transformers/modeling_tf_pytorch_utils.py in load_tf2_checkpoint_in_pytorch_model(pt_model, tf_checkpoint_path, tf_inputs, allow_missing_keys)
    266 import transformers
    267
--> 268 from .modeling_tf_utils import load_tf_weights
    269
    270 logger.info("Loading TensorFlow weights from {}".format(tf_checkpoint_path))

/usr/local/lib/python3.6/dist-packages/transformers/modeling_tf_utils.py in ()
   1029
   1030
-> 1031 def get_initializer(initializer_range: float = 0.02) -> tf.initializers.TruncatedNormal:
   1032 """
   1033 Creates a :obj:`tf.initializers.TruncatedNormal` with the given range.

/usr/local/lib/python3.6/dist-packages/tensorflow/python/util/deprecation_wrapper.py in __getattr__(self, name)
    104 if name.startswith('_dw_'):
    105 raise AttributeError('Accessing local variables before they are created.')
--> 106 attr = getattr(self._dw_wrapped_module, name)
    107 if (self._dw_warning_count < _PER_MODULE_WARNING_LIMIT and
    108 name not in self._dw_deprecated_printed):

AttributeError: module 'tensorflow._api.v1.initializers' has no attribute 'TruncatedNormal'

Which version of transformers does this depend on? Also, the line tokenizer = BertTokenizer.from_pretrained('/content/sample_data/FinBERT_pytorch',from_tf=True) runs without problems.
houpanpan commented 3 years ago

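Hello, we recommend loading the model with Google's official code; that should normally work without problems. The required transformers version is the same version that Google's officially released model depends on.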
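For readers who still want to load the weights with transformers instead, a minimal sketch of picking the `from_pretrained` arguments from what is actually in the model directory is shown below. It assumes the transformers conventions that PyTorch weights are named `pytorch_model.bin` (load the directory, no `from_tf`) and that a TF 1.x checkpoint is loaded by passing its `.index` file with `from_tf=True` plus a `config=` object. The helper `pick_from_pretrained_args` is hypothetical, and the contents of the FinBERT directories are not confirmed by this thread.

```python
import glob
import os
import tempfile
from typing import Any, Dict, Tuple


def pick_from_pretrained_args(model_dir: str) -> Tuple[str, Dict[str, Any]]:
    """Hypothetical helper: choose the path and kwargs to pass to
    BertForSequenceClassification.from_pretrained for model_dir."""
    if os.path.isfile(os.path.join(model_dir, "pytorch_model.bin")):
        # Already PyTorch weights: pass the directory itself, without from_tf.
        return model_dir, {}
    index_files = sorted(glob.glob(os.path.join(model_dir, "*.ckpt*.index")))
    if index_files:
        # TF 1.x checkpoint: pass the .index file with from_tf=True.
        # A BertConfig must also be supplied via config=... when calling.
        return index_files[0], {"from_tf": True}
    raise FileNotFoundError(f"no pytorch_model.bin or *.ckpt*.index in {model_dir}")


if __name__ == "__main__":
    # Demo on a throwaway directory that mimics a PyTorch export.
    with tempfile.TemporaryDirectory() as d:
        open(os.path.join(d, "pytorch_model.bin"), "wb").close()
        print(pick_from_pretrained_args(d))
```

Checking the directory first avoids passing `bert_config.json` as the weights path, which is what routes the call into the TF 2.0 loader.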