shibing624 / textgen

TextGen: Implementation of Text Generation models, include LLaMA, BLOOM, GPT2, BART, T5, SongNet and so on. 文本生成模型,实现了包括LLaMA,ChatGLM,BLOOM,GPT2,Seq2Seq,BART,T5,UDA等模型的训练和预测,开箱即用。
Apache License 2.0
938 stars 109 forks source link

模型训练出错 #41

Closed svjack closed 1 year ago

svjack commented 1 year ago

使用 textgen/examples/llama/training_llama_demo.py 微调模型:https://huggingface.co/shibing624/chinese-llama-plus-13b-hf 使用示例数据集data/zh_csc_train.tsv 有下面的错误 assertion srcindex < srcselectdimsize failed.

该用模型 如:https://huggingface.co/shibing624/chinese-alpaca-plus-13b-hf 则正常。

shibing624 commented 1 year ago

嗯,我留意到了此问题,llama-plus-13b 本地直接预测也会出此错误,alpaca不会。

可能是transformers升级导致的问题,还在排查。

训练13b,可以用其他模型替代,如alpaca-13b, ziya-13b

svjack commented 1 year ago

llama model predict 方法感觉对外暴露的(不由default arg指定的)GenerationConfig 参数感觉有点少 **kwargs 应该考虑重载generation_config 会不会更好一些呢?

shibing624 commented 1 year ago

有 kwargs: https://github.com/shibing624/textgen/blob/main/textgen/llama/llama_model.py#L519

svjack commented 1 year ago

有 kwargs: https://github.com/shibing624/textgen/blob/main/textgen/llama/llama_model.py#L519

像这种参数怎么改呢?

repetition_penalty=self.args.repetition_penalty,
length_penalty=self.args.length_penalty,
svjack commented 1 year ago

https://github.com/huggingface/transformers/issues/24104

shibing624 commented 1 year ago

这样写:https://github.com/shibing624/textgen/blob/main/examples/llama/training_llama_demo.py#L53 写进model_args 就可以,会自动覆盖默认的参数。

svjack commented 1 year ago

这样写:https://github.com/shibing624/textgen/blob/main/examples/llama/training_llama_demo.py#L53 写进model_args 就可以,会自动覆盖默认的参数。

感觉这里面的一些参数不应该在初始化时指定 而应该在生成时是动态的

shibing624 commented 1 year ago

初始化时指定的是默认的,生成时指定可以覆盖默认的,类似于max_length 参数,其他的参数也会覆盖默认的,这个我加下。

shibing624 commented 1 year ago

done