shibing624 / textgen

TextGen: implementation of text generation models, including LLaMA, ChatGLM, BLOOM, GPT2, Seq2Seq, BART, T5, SongNet, UDA, and more, with training and prediction ready to use out of the box.
Apache License 2.0
926 stars 107 forks

BART long-text training issue #14

Open YoungChanYY opened 1 year ago

YoungChanYY commented 1 year ago

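I'm using the BART training code. Each training example has an input text of about 1,000 characters and an output text of 30,000 to 50,000 characters. After a few epochs, training fails with the error shown below. If I limit both the input and output to around 100 characters, however, training runs normally without errors.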

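Question: does the BART model have any limits on input and output length? I suspect something is going wrong with an internal embedding dimension. Thanks.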
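A hedged note on the length question: pretrained BART checkpoints such as facebook/bart-base and facebook/bart-large use learned position embeddings with `max_position_embeddings = 1024`, so a sequence longer than 1024 tokens on either the encoder or the decoder side has no position vector to look up. Limits are in tokens, not characters, and how many tokens a character produces depends on the tokenizer, but a 30,000 to 50,000 character target is very likely far beyond that. The sketch below is a minimal pre-training check, assuming the texts are already tokenized into lists of ids; `find_overlong_examples` and `BART_MAX_POSITIONS` are illustrative names, not part of textgen.

```python
from typing import List, Sequence, Tuple

# Assumption: position-embedding limit of the BART checkpoint in use.
# facebook/bart-base and facebook/bart-large both use 1024.
BART_MAX_POSITIONS = 1024


def find_overlong_examples(
    pairs: Sequence[Tuple[List[int], List[int]]],
    limit: int = BART_MAX_POSITIONS,
) -> List[Tuple[int, int, int]]:
    """Return (index, source_len, target_len) for each pair where either side exceeds `limit` tokens."""
    overlong = []
    for i, (source_ids, target_ids) in enumerate(pairs):
        if len(source_ids) > limit or len(target_ids) > limit:
            overlong.append((i, len(source_ids), len(target_ids)))
    return overlong


if __name__ == "__main__":
    # Toy token-id lists standing in for tokenized (input, output) pairs.
    data = [
        ([1] * 100, [2] * 100),    # short pair, within the limit
        ([1] * 900, [2] * 40000),  # very long target, like the examples in this issue
    ]
    print(find_overlong_examples(data))  # [(1, 900, 40000)]
```

Running a check like this before training shows how many examples would be silently truncated or would overflow the position table.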

RuntimeError: CUDA error: CUBLAS_STATUS_EXECUTION_FAILED when calling cublasSgemmStridedBatched( handle, opa, opb, m, n, k, &alpha, a, lda, stridea, b, ldb, strideb, &beta, c, ldc, stridec, num_batches)

YoungChanYY commented 1 year ago

The error seems to occur in predict. When I disable evaluation during training, training proceeds normally. Any ideas?
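This is consistent with the evaluation path being the place where the decoder runs past its position limit: eval calls `predict`, which calls Hugging Face `generate` and then `greedy_search` (both visible in the traceback below), and greedy search keeps extending the output until it emits an end-of-sequence token or reaches the generation `max_length`. If `max_length` is set large enough to cover 30,000+ character targets, decoder positions can exceed what the model supports. A minimal guard is sketched below, assuming the model limit is known; `safe_generation_length` is a hypothetical helper, not a textgen or transformers function.

```python
import warnings


def safe_generation_length(requested_max_length: int, max_position_embeddings: int = 1024) -> int:
    """Clamp a requested generation length to the decoder's position-embedding limit.

    Hypothetical helper: in transformers the limit would normally be read from
    model.config.max_position_embeddings rather than hard-coded.
    """
    if requested_max_length > max_position_embeddings:
        warnings.warn(
            f"max_length={requested_max_length} exceeds the model limit "
            f"{max_position_embeddings}; clamping."
        )
        return max_position_embeddings
    return requested_max_length


if __name__ == "__main__":
    print(safe_generation_length(50000))  # warns, then prints 1024
    print(safe_generation_length(200))    # 200
```

Clamping avoids the crash but also caps output length, so targets longer than the limit would still need a different approach (for example splitting them into chunks).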

Traceback (most recent call last):
  File "train_bart_text2abc.py", line 180, in <module>
    main()
  File "train_bart_text2abc.py", line 163, in main
    model.train_model(train_df, eval_data=eval_df, split_on_space=True, matches=count_matches)
  File "textgen/seq2seq/bart_seq2seq_model.py", line 452, in train_model
    **kwargs,
  File "textgen/seq2seq/bart_seq2seq_model.py", line 983, in train
    **kwargs,
  File "textgen/seq2seq/bart_seq2seq_model.py", line 1153, in eval_model
    preds = self.predict(to_predict, split_on_space=split_on_space)
  File "textgen/seq2seq/bart_seq2seq_model.py", line 1310, in predict
    num_return_sequences=self.args.num_return_sequences,
  File "/usr/local/lib/python3.7/dist-packages/torch/autograd/grad_mode.py", line 27, in decorate_context
    return func(*args, **kwargs)
  File "/usr/local/lib/python3.7/dist-packages/transformers/generation/utils.py", line 1400, in generate
    **model_kwargs,
  File "/usr/local/lib/python3.7/dist-packages/transformers/generation/utils.py", line 2183, in greedy_search
    output_hidden_states=output_hidden_states,
  File "/usr/local/lib/python3.7/dist-packages/torch/nn/modules/module.py", line 1110, in _call_impl
    return forward_call(*input, **kwargs)
  File "/usr/local/lib/python3.7/dist-packages/transformers/models/bart/modeling_bart.py", line 1389, in forward
    return_dict=return_dict,
  File "/usr/local/lib/python3.7/dist-packages/torch/nn/modules/module.py", line 1110, in _call_impl
    return forward_call(*input, **kwargs)
  File "/usr/local/lib/python3.7/dist-packages/transformers/models/bart/modeling_bart.py", line 1268, in forward
    return_dict=return_dict,
  File "/usr/local/lib/python3.7/dist-packages/torch/nn/modules/module.py", line 1110, in _call_impl
    return forward_call(*input, **kwargs)
  File "/usr/local/lib/python3.7/dist-packages/transformers/models/bart/modeling_bart.py", line 1124, in forward
    use_cache=use_cache,
  File "/usr/local/lib/python3.7/dist-packages/torch/nn/modules/module.py", line 1110, in _call_impl
    return forward_call(*input, **kwargs)
  File "/usr/local/lib/python3.7/dist-packages/transformers/models/bart/modeling_bart.py", line 431, in forward
    output_attentions=output_attentions,
  File "/usr/local/lib/python3.7/dist-packages/torch/nn/modules/module.py", line 1110, in _call_impl
    return forward_call(*input, **kwargs)
  File "/usr/local/lib/python3.7/dist-packages/transformers/models/bart/modeling_bart.py", line 275, in forward
    attn_output = torch.bmm(attn_probs, value_states)
RuntimeError: CUDA error: CUBLAS_STATUS_EXECUTION_FAILED when calling cublasSgemmStridedBatched( handle, opa, opb, m, n, k, &alpha, a, lda, stridea, b, ldb, strideb, &beta, c, ldc, stridec, num_batches)
../aten/src/ATen/native/cuda/Indexing.cu:650: indexSelectSmallIndex: block: [0,0,0], thread: [0,0,0] Assertion srcIndex < srcSelectDimSize failed.
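A hedged reading of the last line: the assertion srcIndex < srcSelectDimSize in indexSelectSmallIndex is the CUDA index-select kernel rejecting an out-of-range index, which is what an embedding lookup produces when a position id is larger than the embedding table. Because CUDA reports kernel errors asynchronously, the CUBLAS_STATUS_EXECUTION_FAILED raised at `torch.bmm` may only be a downstream symptom; rerunning with the environment variable `CUDA_LAUNCH_BLOCKING=1` usually makes the stack trace point at the kernel that actually failed. The pure-Python sketch below mimics a fixed-size learned position table to show the failure mode; `LearnedPositionTable` is a toy illustration, not the code in `modeling_bart.py`.

```python
from typing import List


class LearnedPositionTable:
    """Toy stand-in for a learned position-embedding table with a fixed number of rows."""

    def __init__(self, num_positions: int, dim: int) -> None:
        # One dummy vector per supported position: row p is [p, p, ..., p].
        self.rows: List[List[float]] = [[float(p)] * dim for p in range(num_positions)]

    def lookup(self, seq_len: int) -> List[List[float]]:
        """Return vectors for positions 0..seq_len-1, failing the way an index-select would."""
        num_rows = len(self.rows)
        if seq_len > num_rows:
            # Same condition the CUDA assertion checks: every index must be < table size.
            raise IndexError(f"position {seq_len - 1} >= table size {num_rows}")
        return [self.rows[p] for p in range(seq_len)]


if __name__ == "__main__":
    table = LearnedPositionTable(num_positions=1024, dim=4)
    print(len(table.lookup(100)))  # 100: short sequences work
    try:
        table.lookup(5000)  # a long sequence, as in this issue
    except IndexError as err:
        print("IndexError:", err)  # IndexError: position 4999 >= table size 1024
```

On the GPU the same out-of-bounds lookup cannot raise a Python exception from inside the kernel, so it surfaces as a device-side assertion followed by a generic CUDA error.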

shibing624 commented 1 year ago

I'll take a look at the evaluate logic.

YoungChanYY commented 1 year ago

Thanks a lot.

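I noticed another spot that looks problematic: in the preprocess_data_bart(data) function in textgen/seq2seq/bart_seq2seq_utils.py, target_ids is encoded with the wrong length argument. The issue and my suggested fix are below; could you check whether this is right? Thanks!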

def preprocess_data_bart(data):
    input_text, target_text, tokenizer, args = data
    # ......
    target_ids = tokenizer.batch_encode_plus(
        [target_text],
        # max_length=args.max_seq_length,  # original code
        max_length=args.max_length,  # suggested code
        padding="max_length",
        return_tensors="pt",
        truncation=True,
    )

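The idea behind the suggested change is that input and output get separate length budgets: `args.max_seq_length` for the source and `args.max_length` for the target, with truncation keeping each within its budget (as long as both are set no higher than the model's position limit). Below is a tokenizer-free sketch of that truncate-then-pad step, assuming sequences are already lists of token ids; `truncate_and_pad`, `encode_pair`, and the example limits are hypothetical, and real code would keep using `tokenizer.batch_encode_plus(..., padding="max_length", truncation=True)` as in the snippet above. The default `pad_id=1` matches BART's usual pad token id, but check your tokenizer. Unlike the real tokenizer, this sketch does not re-append an end-of-sequence token after truncating.

```python
from typing import List, Tuple


def truncate_and_pad(ids: List[int], max_len: int, pad_id: int) -> List[int]:
    """Cut `ids` to at most `max_len` tokens, then right-pad with `pad_id` to exactly `max_len`."""
    ids = ids[:max_len]
    return ids + [pad_id] * (max_len - len(ids))


def encode_pair(
    source_ids: List[int],
    target_ids: List[int],
    max_seq_length: int,
    max_length: int,
    pad_id: int = 1,
) -> Tuple[List[int], List[int]]:
    """Apply the source budget to the input and the target budget to the output (hypothetical helper)."""
    return (
        truncate_and_pad(source_ids, max_seq_length, pad_id),
        truncate_and_pad(target_ids, max_length, pad_id),
    )


if __name__ == "__main__":
    src, tgt = encode_pair([5, 6, 7], [8] * 10, max_seq_length=5, max_length=4)
    print(src)  # [5, 6, 7, 1, 1]
    print(tgt)  # [8, 8, 8, 8]
```

With the original code, the target was cut to the source budget instead, so the two limits could not be tuned independently.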
shibing624 commented 1 year ago

Correct, fixed: https://github.com/shibing624/textgen/commit/7a0be5931234262165d32fc0f915af822e0b1665

stale[bot] commented 9 months ago

This issue has been automatically marked as stale because it has not had recent activity. It will be closed if no further activity occurs. Thank you for your contributions. (Closed automatically by the bot due to long inactivity; feel free to ask again if needed.)