Jeryi-Sun / SPACES-Pytorch

苏神SPACE pytorch版本复现
MIT License
41 stars 4 forks source link

请问要如何用训练好的模型直接输入新文本得到摘要呢? #7

Open PolarisRisingWar opened 2 years ago

PolarisRisingWar commented 2 years ago

您好,我看到在SPACES原项目里是用final.py里面的函数来直接输出模型,想请问这个PyTorch版写法里怎么写呢?

Jeryi-Sun commented 2 years ago

使用seq2seq_model 里的class AutoSummary(AutoRegressiveDecoder) 即可

PolarisRisingWar commented 2 years ago

您好,我现在在文件夹下新建了一个py文件,代码为:

import seq2seq_model as seq2seq
import torch

def predict(text, topk=3):
    model=seq2seq.GenerateModel().to('cuda:1')
    model.load_state_dict(torch.load('checkpoint/seq2seq-49.pkl')['model'])
    autosummary = seq2seq.AutoSummary(
        start_id=model.tokenizer.cls_token_id,
        end_id=model.tokenizer.sep_token_id,
        maxlen=1024 // 2,
        model=model
    )
    pred_summary = autosummary.generate(text,topk)
    # 返回
    return pred_summary

predict('哈尔滨银行股份有限公司阿城支行与魏春超借款合同纠纷一审民事 判决 书黑龙江省哈尔滨市阿城区人民法院民 事 判 决 书(2017)黑0112民初1643号原告哈尔滨银行股份有限公司阿城支行,所在地哈尔滨市阿城区延川大街。负责人车大伟,职务行长。委托代理人张晓庆,男,1982年9月17日出生,汉族,哈尔滨银行股份有限公司阿城支行职员,现住所哈尔滨市香坊区。委托代理人辛晓伟,男,1987年6月18日出生,汉族,哈尔滨银行股份有限公司阿城支行职员,现住所哈尔滨市。被告魏春超,男,1984年4月30日出生,汉族,农民,现住所哈尔滨市阿城区。原告哈尔滨银行股份有限公司阿城支行与被告魏春超借款合同纠纷一案,本院受理后,依法组成合议庭,公开开庭进行了审理,原告委托代理人张晓庆、辛晓伟到庭参加诉讼,被告魏春超经本院合法传唤未到庭参加诉讼,本案现已审理终结。原告哈尔滨银行股份有限公司阿城支行向本院提出诉讼请求:1、请求被告魏春超偿还借款本金343500元及利息(计算至履行完毕时止);2、诉讼费用由被告负担。事实及理由:2014年6月30日,原告与被告魏春超签订个人农户种植贷款借款合同,约定:借款人魏春超在原告处借款343500元,借款期限自2014年6月30日至2015年7月30日,年利率为12.96%。原告按合同约定履行放款义务,被告未按约定履行还款及担保义务。被告魏春超经合法传唤未到庭参加诉讼,亦未向本院提交答辩状及抗辩证据。在本院开庭审理过程中,原告为证明其诉讼主张的事实成立,举示了以下证据:证据一、个人农户种植贷款借款合同复印件一份(已与原件核对无误)。拟证实2014年6月30日,原告与被告魏春超签订个人农户种植贷款借款合同,约定:借款人魏春超在原告处借款343500元,借款期限自2014年6月30日至2015年7月30日,年利率为12.96%,贷款发生逾期,逾期罚息利率为本合同约定借款利率的150%。在本合同期限内,借款的实际放款日和约定还款日以借据为准,借据为本合同组成部分,与本合同具有同等法律效力。证据二、哈尔滨银行借款凭证复印件一份(已与原件核对无误)。拟证实2014年6月30日,魏春超在原告处借款343500元,借款期限为13个月,年利率为12.96%,还款时间为2015年7月30日。证据三、身份证及户口复印件各一份。拟证实被告身份。证据四、营业执照及组织机构代码证复印件一份。拟证实原告经营范围及期限。证据五、哈尔滨市阿城区阿什河街东环村民委员会证明一份。拟证实魏春超系该村村民,常年在外打工,下落不明。审判长 :因被告经本院合法传唤未到庭参加诉讼,视为自愿放弃质证权利,应承担原告举证对其不利后果,故本院对原告提交五份证据认定为有效证据。因被告魏春超、孙振广未向本院提供任何反驳原告诉讼请求的证据和放弃出庭抗辩权利,应承担原告举证对其不利的后果,上述证据来源合法与本案具有关联性,符合证据规则认定的条件。故本院对原告提交五份证据予以确认。结合当事人的诉讼请求、当庭陈述和本院对当事人提交证据的分析判断,本院认定案件事实如下:2014年6月30日,哈尔滨银行股份有限公司阿城支行与魏春超签订个人农户种植贷款借款合同,约定:借款人魏春超在原告处借款343500元,借款期限自2014年6月30日至2015年7月30日,年利率为12.96%,贷款发生逾期,逾期罚息利率为本合同约定借款利率的150%。在本合同期限内,借款的实际放款日和约定还款日以借据为准,借据为本合同组成部分,与本合同具有同等法律效力。同日,原告按约定履行放款义务,魏春超给原告出具借款凭证一份,凭证记载:魏春超在原告处借款343500元,借款期限为13个月,年利率为12.96%,还款时间为2015年7月30日。此款到期后,二被告未按约定履行还款及担保义务。现原告诉至本院,请求判令被告魏春超偿还借款本金343500元及利息(计算方法:以本金343500元为基准,按合同约定利率及罚息,自2014年6月30日起至履行完毕时止)。本院认为,原告与被告魏春超签订个人农户种植贷款借款合同,系双方当事人的真实意思表示,且不违反法律、行政法规的禁止性规定,合法有效,双方应按约定履行合同。合同签订后,')

就是仿照seq2seq_model.py里面对应部分的代码来写的。predict的输入是复制了一篇正文。cuda:1是训练时使用的GPU,如果不加上会报错。

现在还是报错。完整输出:

Building prefix dict from the default dictionary ...
Loading model from cache (mypath)/sfzy/jiebatemp/jieba.cache
Loading model cost 0.557 seconds.
Prefix dict has been built successfully.
(mypath)/bert_cache/nezha-chinese-base/pytorch_model.bin loaded!
Traceback (most recent call last):
  File "(mypath)/projects/SPACES-Pytorch/final_whj.py", line 17, in <module>
    predict('唐鲜明与何伟华、深圳市华名威电汽车服务有限公司侵权责任纠纷一审民事判决书。广东省深圳市宝安区人民法院')
  File "(mypath)/projects/SPACES-Pytorch/final_whj.py", line 13, in predict
    pred_summary = autosummary.generate(text,topk)
  File "(mypath)/projects/SPACES-Pytorch/seq2seq_model.py", line 341, in generate
    output_ids = self.beam_search([token_ids, segment_ids],
  File "(mypath)/projects/SPACES-Pytorch/snippets.py", line 124, in beam_search
    scores, states = self.predict(
  File "(mypath)/projects/SPACES-Pytorch/snippets.py", line 82, in new_predict
    prediction = predict(self, inputs, output_ids, states)
  File "(mypath)/projects/SPACES-Pytorch/seq2seq_model.py", line 302, in predict
    seq2seq_predictions, copy_predictions = self.model(torch.tensor(token_ids, device=device), torch.tensor(segment_ids, device=device))
  File "(mypath)/conda_envs/envlegalai1/lib/python3.8/site-packages/torch/nn/modules/module.py", line 1102, in _call_impl
    return forward_call(*input, **kwargs)
  File "(mypath)/projects/SPACES-Pytorch/seq2seq_model.py", line 213, in forward
    seq2seq_predictions,  hidden_state = self.bert_model(token_ids, token_type_ids)
ValueError: not enough values to unpack (expected 2, got 1)

想请问您能看出这是什么原因吗?