THUDM / WebGLM

WebGLM: An Efficient Web-enhanced Question Answering System (KDD 2023)
Apache License 2.0
1.57k stars 135 forks source link

input length of input ids的长度大于1024 #5

Open linuxonly801 opened 1 year ago

linuxonly801 commented 1 year ago

配置好本地环境,使用WebGLM-2B模型。提问:Is HER2 gene a good target for treating cancer?

出现如下报错:

Input length of input_ids is 1056, but max_length is set to 1024. This can lead to unexpected behavior. You should consider increasing max_new_tokens. Traceback (most recent call last): File "cli_demo.py", line 21, in for results in webglm.stream_query(question): File "/media/WebGLM/model/modeling_webglm.py", line 49, in stream_query outputs = self.model.generate(inputs, max_length=1024, eos_token_id = self.tokenizer.eop_token_id, pad_token_id=self.tokenizer.eop_token_id) File "/usr/local/miniconda3/lib/python3.8/site-packages/torch/autograd/grad_mode.py", line 27, in decorate_context return func(*args, *kwargs) File "/usr/local/miniconda3/lib/python3.8/site-packages/transformers/generation/utils.py", line 1515, in generate return self.greedy_search( File "/usr/local/miniconda3/lib/python3.8/site-packages/transformers/generation/utils.py", line 2332, in greedy_search outputs = self( File "/usr/local/miniconda3/lib/python3.8/site-packages/torch/nn/modules/module.py", line 1130, in _call_impl return forward_call(input, kwargs) File "/root/.cache/huggingface/modules/transformers_modules/modeling_glm.py", line 902, in forward model_output = self.glm(input_ids, position_ids, attention_mask, mems=mems, kwargs) File "/usr/local/miniconda3/lib/python3.8/site-packages/torch/nn/modules/module.py", line 1130, in _call_impl return forward_call(*input, *kwargs) File "/root/.cache/huggingface/modules/transformers_modules/modeling_glm.py", line 783, in forward transformer_output = self.transformer(embeddings, position_ids, attention_mask, mems) File "/usr/local/miniconda3/lib/python3.8/site-packages/torch/nn/modules/module.py", line 1130, in _call_impl return forward_call(input, kwargs) File "/root/.cache/huggingface/modules/transformers_modules/modeling_glm.py", line 595, in forward hidden_states = layer(args, mem=mem_i) File "/usr/local/miniconda3/lib/python3.8/site-packages/torch/nn/modules/module.py", line 1130, in _call_impl return forward_call(input, **kwargs) File "/root/.cache/huggingface/modules/transformers_modules/modeling_glm.py", line 422, in forward layernorm_input = hidden_states + attention_output RuntimeError: CUDA error: device-side assert triggered CUDA kernel errors might be asynchronously reported at some other API call,so the stacktrace below might be incorrect. For debugging consider passing CUDA_LAUNCH_BLOCKING=1.

ilovesouthpark commented 1 year ago

modeling_webglm.py里把1024改大

mikestut commented 1 year ago

outputs = self.model.generate(**inputs, max_length=2048, eos_token_id = self.tokenizer.eop_token_id, pad_token_id=self.tokenizer.eop_token_id) 修改后还是报错,求解决!! image

ilovesouthpark commented 1 year ago

用cpu跑会发现真正的错误原因是IndexError: index out of range in self,这个问题超过我的能力范围去解决了,但是应该是个普遍的问题,看看开发团队能否提供额外的参数让大家方便调整。

traveler-vee commented 1 year ago

关注中

hanjingsu commented 1 year ago

同遇到问题,中文问的时候会碰到,英文问目前还正常,看起出来跟中文搜索结果截出来的关键字没按照预订长度有关系

TailyDuan commented 1 year ago

IndexError: index out of range in self

WnQinm commented 5 months ago

我搜索英文也会遇到这个问题,看到作者说不能修改max_length,所以只能做截断,但是我没有找到显式的截断api

在我这里,报错input_ids长度大于1024的原因不是用户输入的prompt过长,而是作者代码没有对搜索引擎搜索到的reference按规定长度截断,并将他们直接添加到了prompt中,导致input_ids大小超过1024。

解决方法是修改modeling_webglm.py中的query函数或stream_query函数。计算每个ref对应的token长度,限制总的prompt长度:

def query(self, question):
    refs = self.ref_retriever.query(question)
    if not refs:
        return { "references": [], "answer": "" }
    prompt = ''
    question = f'Question: {question}\\Answer: [gMASK]'
    total_token_num = self.tokenizer(question, return_tensors="pt").input_ids.shape[1]
    for ix, ref in enumerate(refs):
        txt = ref["text"]
        prompt_tmp = f'Reference [{ix+1}]: {txt}' '\\'
        prompt_tmp_token_num = self.tokenizer(prompt_tmp, return_tensors="pt").input_ids.shape[1]
        if total_token_num + prompt_tmp_token_num < 900:
            prompt += prompt_tmp
            total_token_num += prompt_tmp_token_num
        else:
            break
    prompt += question
    inputs = self.tokenizer(prompt, return_tensors="pt")
    # other code

尽管这样做很简单(会由于超过长度限制的ref[i]而忽略符合长度限制的ref[i+1])且效率有点低,但确实对我有用:)


我研究了一下代码,model/retriever/extracting/__init__.pyExtractor类的_pre_filter方法应该是在约束检索到的每个url对应页面中每一个段落的长度