sunzeyeah / RLHF

Implementation of Chinese ChatGPT
283 stars 35 forks source link

为什么训练的时候要加入<sep> token? #2

Closed Nipi64310 closed 1 year ago

Nipi64310 commented 1 year ago

Hi, @sunzeyeah Thank you for sharing your amazing work.

为什么训练的时候要加入 token,预测不需要

https://github.com/sunzeyeah/RLHF/blob/master/src/data/data.py#L167 训练的时候会加入sep id 5,
https://huggingface.co/sunzeyeah/pangu-2.6B-sft 这里提供的示例代码,会将 切成多个id 示例如下:

>>> from transformers import TextGenerationPipeline, AutoTokenizer, AutoModelForCausalLM
>>> import torch
>>> 
>>> model_id = 'pangu-2.6B-sft‘
>>> tokenizer = AutoTokenizer.from_pretrained(model_id, trust_remote_code=True,use_cache=False)
Explicitly passing a `revision` is encouraged when loading a model with custom code to ensure no malicious code has been contributed in a newer revision.
>>> tokenizer('你好啊','答:')
Building prefix dict from the default dictionary ...
Loading model from cache /tmp/jieba.cache
Loading model cost 0.690 seconds.
Prefix dict has been built successfully.
{'input_ids': [1, 5772, 173, 5, 3330, 17, 9], 'token_type_ids': [0, 0, 1, 1], 'attention_mask': [1, 1, 1, 1, 1, 1, 1]}
>>> tokenizer('你好啊<sep>答:')
{'input_ids': [1, 5772, 173, 13, 1674, 2314, 21413, 716, 3330, 17, 9], 'token_type_ids': [0, 0, 0, 0, 0, 0, 0, 0, 0], 'attention_mask': [1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1]}
>>>
sunzeyeah commented 1 year ago

你好,在使用pangu类模型的时候,tokenization_gptpangu.pytokenize()函数会先用jieba进行分词。而直接pip install jieba的话,默认会将<>直接切分开,哪怕使用jieba.add_word("<sep>")也没有作用,因为jieba直接hardcode了会自动切分的token,其中就包括了<>

因此需要将jieba代码clone到本地,修改jieba/__init__.pyre_han_default的取值,具体改动如下:

修改完成后使用pip install .进行本地编译安装,替换原有jieba。安装完成后,在代码中加入jieba.add_word("<sep>")(该代码已加入tokenization_gptpangu.py),这样即可解决将<sep>切分为多个id的情况

sunzeyeah commented 1 year ago

关于是否一定需要在训练阶段加入<sep>,个人觉得不是绝对的yes或者no,更多的是一个偏经验的做法。主要动机就是给模型一个token或者信号,来分隔开promptanswer

Nipi64310 commented 1 year ago

关于是否一定需要在训练阶段加入<sep>,个人觉得不是绝对的yes或者no,更多的是一个偏经验的做法。主要动机就是给模型一个token或者信号,来分隔开promptanswer

感谢回复!