RUCKBReasoning / RESDSQL

The Pytorch implementation of RESDSQL (AAAI 2023).
https://arxiv.org/abs/2302.05965
MIT License
245 stars 58 forks source link

xlm_roberta_text2natsql_schema_item_classifier #70

Open zapython opened 8 months ago

zapython commented 8 months ago

xlm_roberta_text2natsql_schema_item_classifier 为什么不能使用 报错 OSError: Error no file named pytorch_model.bin, tf_model.h5, model.ckpt.index or flax_model.msgpack found 怎样解决

lihaoyang-ruc commented 7 months ago

这有可能是因为下载的压缩包出现了丢包的情况,请尝试重新下载xlm_roberta_text2natsql_schema_item_classifier (trained on CSpider)并解压(密码是3sdu)

atom0407 commented 6 months ago

我重新下载后还是这个报错,OSError: Error no file named pytorch_model.bin, tf_model.h5, model.ckpt.index or flax_model.msgpack found,还是缺少文件,请问楼主解决了嘛?

atom0407 commented 6 months ago

我已解决该问题:(1)修改classifier_model.py文件中的第21行,改为config = AutoConfig.from_pretrained(model_name_or_path) self.plm_encoder = model_class(config) (2)在schema_item_classifier.py文件的第236行加入model.load_state_dict(torch.load(opt.model_name_or_path + "/dense_classifier.pt", map_location=torch.device('cpu')))即可。