THUDM / ChatGLM2-6B

ChatGLM2-6B: An Open Bilingual Chat LLM

[BUG/Help] Repetition penalty not taking effect #83

Open · 1MLightyears opened this issue 1 year ago

1MLightyears commented 1 year ago

Is there an existing issue for this?

Current Behavior

I deployed ChatGLM2 successfully following the steps in README.md. However, while trying to benchmark its generation speed, I found that it tends to return the same content for the same prompt. #43 mentions this problem as well, but the repetition_penalty option has no effect (or did I pass it in the wrong place? No error is raised, in any case). When testing with IPython's %timeit, the prompt is identical on every call, and even with repetition_penalty=0.8 the generated text is identical word for word. If I lower the value further, there is no output at all (the process seems to hang, and every Ctrl+C lands on the same statement).

Expected Behavior

When a smaller repetition_penalty is set, the model should produce different content.

Steps To Reproduce

>>> from transformers import AutoTokenizer, AutoModel
>>> tokenizer = AutoTokenizer.from_pretrained("THUDM/chatglm2-6b", trust_remote_code=True)
>>> model = AutoModel.from_pretrained("THUDM/chatglm2-6b", trust_remote_code=True, device='cuda')
>>> model = model.eval()
>>> hist = []
>>> def chat(prompt, r_p=1):
    ...:     global hist
    ...:     resp, hist = model.chat(tokenizer, prompt, history=hist, repetition_penalty=r_p)
    ...:     print(resp)
    ...:     return resp, hist
>>> %timeit chat("以下将多次向你发送同样内容,你可以随意决定回答内容。发送这些内容的目的是为了衡量你的回复速度。", 0.8)
好的,我会尽快回复你的内容。请注意,回复速度可能会受到当前工作量、网络质量和其他因素的影响,因此可能会有一定的延迟。
# (English: "OK, I will reply to your message as soon as possible. Please note that the reply speed may be affected by the current workload, network quality, and other factors, so there may be some delay.")
# The same response was printed verbatim on every one of the eight calls made by %timeit.
23.6 s ± 5.34 s per loop (mean ± std. dev. of 7 runs, 1 loop each)
>>> %timeit chat("以下将多次向你发送同样内容,你可以随意决定回答内容。发送这些内容的目的是为了衡量你的回复速度。", 0.5)
# no output for a long time
# stack trace after pressing Ctrl+C
^C---------------------------------------------------------------------------
KeyboardInterrupt                         Traceback (most recent call last)
Cell In[20], line 1
----> 1 get_ipython().run_line_magic('timeit', 'chat("以下将多次向你发送同样内容,你可以随意决定回答内容,但请尽量避免回复重复内容。发送这些内容的目的是为了衡量你的回复速度。", 0.5)')

File ~/miniconda3/lib/python3.10/site-packages/IPython/core/interactiveshell.py:2417, in InteractiveShell.run_line_magic(self, magic_name, line, _stack_depth)
   2415     kwargs['local_ns'] = self.get_local_scope(stack_depth)
   2416 with self.builtin_trap:
-> 2417     result = fn(*args, **kwargs)
   2419 # The code below prevents the output from being displayed
   2420 # when using magics with decodator @output_can_be_silenced
   2421 # when the last Python token in the expression is a ';'.
   2422 if getattr(fn, magic.MAGIC_OUTPUT_CAN_BE_SILENCED, False):

File ~/miniconda3/lib/python3.10/site-packages/IPython/core/magics/execution.py:1170, in ExecutionMagics.timeit(self, line, cell, local_ns)
   1168 for index in range(0, 10):
   1169     number = 10 ** index
-> 1170     time_number = timer.timeit(number)
   1171     if time_number >= 0.2:
   1172         break

File ~/miniconda3/lib/python3.10/site-packages/IPython/core/magics/execution.py:158, in Timer.timeit(self, number)
    156 gc.disable()
    157 try:
--> 158     timing = self.inner(it, self.timer)
    159 finally:
    160     if gcold:

File <magic-timeit>:1, in inner(_it, _timer)

Cell In[17], line 3, in chat(prompt)
      1 def chat(prompt):
      2     global hist
----> 3     resp, hist = m_e.chat(tokenizer, prompt, history=hist, repetition_penalty=0.5)
      4     print(resp)
      5     return resp, hist

File ~/miniconda3/lib/python3.10/site-packages/torch/utils/_contextlib.py:115, in context_decorator.<locals>.decorate_context(*args, **kwargs)
    112 @functools.wraps(func)
    113 def decorate_context(*args, **kwargs):
    114     with ctx_factory():
--> 115         return func(*args, **kwargs)

File ~/.cache/huggingface/modules/transformers_modules/THUDM/chatglm2-6b/5f2cf3359164692d1caec34a6f31c80c658b7ffa/modeling_chatglm.py:948, in ChatGLMForConditionalGeneration.chat(self, tokenizer, query, history, max_length, num_beams, do_sample, top_p, temperature, logits_processor, **kwargs)
    945 gen_kwargs = {"max_length": max_length, "num_beams": num_beams, "do_sample": do_sample, "top_p": top_p,
    946               "temperature": temperature, "logits_processor": logits_processor, **kwargs}
    947 inputs = self.build_inputs(tokenizer, query, history=history)
--> 948 outputs = self.generate(**inputs, **gen_kwargs)
    949 outputs = outputs.tolist()[0][len(inputs["input_ids"][0]):]
    950 response = tokenizer.decode(outputs)

File ~/miniconda3/lib/python3.10/site-packages/torch/utils/_contextlib.py:115, in context_decorator.<locals>.decorate_context(*args, **kwargs)
    112 @functools.wraps(func)
    113 def decorate_context(*args, **kwargs):
    114     with ctx_factory():
--> 115         return func(*args, **kwargs)

File ~/miniconda3/lib/python3.10/site-packages/transformers/generation/utils.py:1572, in GenerationMixin.generate(self, inputs, generation_config, logits_processor, stopping_criteria, prefix_allowed_tokens_fn, synced_gpus, assistant_model, streamer, **kwargs)
   1564     input_ids, model_kwargs = self._expand_inputs_for_generation(
   1565         input_ids=input_ids,
   1566         expand_size=generation_config.num_return_sequences,
   1567         is_encoder_decoder=self.config.is_encoder_decoder,
   1568         **model_kwargs,
   1569     )
   1571     # 13. run sample
-> 1572     return self.sample(
   1573         input_ids,
   1574         logits_processor=logits_processor,
   1575         logits_warper=logits_warper,
   1576         stopping_criteria=stopping_criteria,
   1577         pad_token_id=generation_config.pad_token_id,
   1578         eos_token_id=generation_config.eos_token_id,
   1579         output_scores=generation_config.output_scores,
   1580         return_dict_in_generate=generation_config.return_dict_in_generate,
   1581         synced_gpus=synced_gpus,
   1582         streamer=streamer,
   1583         **model_kwargs,
   1584     )
   1586 elif is_beam_gen_mode:
   1587     if generation_config.num_return_sequences > generation_config.num_beams:

File ~/miniconda3/lib/python3.10/site-packages/transformers/generation/utils.py:2632, in GenerationMixin.sample(self, input_ids, logits_processor, stopping_criteria, logits_warper, max_length, pad_token_id, eos_token_id, output_attentions, output_hidden_states, output_scores, return_dict_in_generate, synced_gpus, streamer, **model_kwargs)
   2629 next_token_logits = outputs.logits[:, -1, :]
   2631 # pre-process distribution
-> 2632 next_token_scores = logits_processor(input_ids, next_token_logits)
   2633 next_token_scores = logits_warper(input_ids, next_token_scores)
   2635 # Store scores, attentions and hidden_states when required

File ~/miniconda3/lib/python3.10/site-packages/transformers/generation/logits_process.py:92, in LogitsProcessorList.__call__(self, input_ids, scores, **kwargs)
     90         scores = processor(input_ids, scores, **kwargs)
     91     else:
---> 92         scores = processor(input_ids, scores)
     93 return scores

File ~/.cache/huggingface/modules/transformers_modules/THUDM/chatglm2-6b/5f2cf3359164692d1caec34a6f31c80c658b7ffa/modeling_chatglm.py:53, in InvalidScoreLogitsProcessor.__call__(self, input_ids, scores)
     52 def __call__(self, input_ids: torch.LongTensor, scores: torch.FloatTensor) -> torch.FloatTensor:
---> 53     if torch.isnan(scores).any() or torch.isinf(scores).any():
     54         scores.zero_()
     55         scores[..., 5] = 5e4

KeyboardInterrupt:
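
For context on the hang, below is a minimal, standalone sketch of how transformers applies repetition_penalty to the logits; it is not code from this repository. With a penalty below 1, tokens already present in the context have their scores raised rather than lowered, so generation can keep cycling over the same tokens and never reach EOS, which would match the apparent freeze at repetition_penalty=0.5.

import torch
from transformers import RepetitionPenaltyLogitsProcessor

# Toy vocabulary of 5 tokens; assume tokens 2 and 3 already appear in the context.
input_ids = torch.tensor([[2, 3]])
scores = torch.tensor([[1.0, 2.0, 3.0, -1.0, 0.5]])

for penalty in (1.0, 1.2, 0.5):
    # The processor modifies scores in place, so clone before each call.
    out = RepetitionPenaltyLogitsProcessor(penalty=penalty)(input_ids, scores.clone())
    print(penalty, out.tolist())

# penalty > 1: the scores of tokens 2 and 3 shrink (positive scores are divided by the
#              penalty, negative ones multiplied), so repeating them becomes less likely.
# penalty < 1: the opposite happens, already-generated tokens are boosted, so the model
#              can loop on the same tokens indefinitely and never emit EOS.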

Environment

- OS: Ubuntu 22.04
- Python: 3.10.10
- Transformers: 4.30.2
- PyTorch: 2.0.1+cu118
- CUDA Support (`python -c "import torch; print(torch.cuda.is_available())"`) : True

Anything else?

No response

tjsky commented 1 year ago

repetition_penalty = 1 means no penalty on repetition, > 1 penalizes repetition, and < 1 encourages repetition. As a rule, do not set it below 0.5; values that low are very likely to cause catastrophic repetition.

The other parameters are at their web_demo.py defaults.
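
For completeness, here is a minimal sketch in the direction tjsky describes: a repetition_penalty above 1 penalizes repeats instead of encouraging them. The specific values below are illustrative assumptions only, not defaults or recommendations from the maintainers.

from transformers import AutoTokenizer, AutoModel

tokenizer = AutoTokenizer.from_pretrained("THUDM/chatglm2-6b", trust_remote_code=True)
model = AutoModel.from_pretrained("THUDM/chatglm2-6b", trust_remote_code=True, device='cuda').eval()

# repetition_penalty is forwarded to generate() through chat()'s **kwargs
# (see the gen_kwargs construction in modeling_chatglm.py in the traceback above).
resp, hist = model.chat(
    tokenizer,
    "你好",
    history=[],
    do_sample=True,           # sampling, so repeated calls can differ
    repetition_penalty=1.2,   # > 1 discourages repeated tokens; 1.2 is only an example value
)
print(resp)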