yangjianxin1 / GPT2-chitchat

GPT2 for Chinese chitchat (implements the MMI idea from DialoGPT)

error---->train.py #85

Open KangChou opened 2 years ago

KangChou commented 2 years ago

python train.py --epochs 1 --batch_size 2 --device 0 --train_path data/train_2w.pkl

th='data/train.log', log_step=1, lr=2.6e-05, max_grad_norm=2.0, max_len=150, model_config='config/config.json', no_cuda=False, num_workers=0, pad_id=0, patience=0, pretrained_model='', save_model_path='model', sep_id=102, train_path='data/train_2w.pkl', val_num=8000, vocab_path='vocab/vocab.txt', warmup_steps=200)
2021-11-30 13:46:19,241 - INFO - loading training dataset and validating dataset
Traceback (most recent call last):
  File "train.py", line 427, in <module>
    main()
  File "train.py", line 423, in main
    train(model, logger, train_dataset, validate_dataset, args)
  File "train.py", line 270, in train
    drop_last=True
  File "/opt/conda/lib/python3.6/site-packages/torch/utils/data/dataloader.py", line 266, in __init__
    sampler = RandomSampler(dataset, generator=generator)  # type: ignore
  File "/opt/conda/lib/python3.6/site-packages/torch/utils/data/sampler.py", line 104, in __init__
    "value, but got num_samples={}".format(self.num_samples))
ValueError: num_samples should be a positive integer value, but got num_samples=0
Saraooe commented 2 years ago

Change line 348 of train.py to input_ids = input_ids.to(device)

Dynamicboboo commented 2 years ago

Has this been resolved?

barbara-su commented 2 years ago

Change line 348 of train.py to input_ids = input_ids.to(device)

Line 348 is "def calculate_acc(logit, labels, ignore_index=-100):", so how am I supposed to change that?

barbara-su commented 2 years ago

I ran into this problem too. Very strange.

FuryMartin commented 2 years ago

I ran into this problem too. Very strange.

See r1cebank's answer: https://github.com/yangjianxin1/GPT2-chitchat/issues/79#issuecomment-941684250
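
For context on the traceback in the original post: RandomSampler raises "num_samples should be a positive integer value, but got num_samples=0" when the dataset handed to the DataLoader has length 0. The logged arguments show val_num=8000, so one plausible cause (an assumption, not confirmed in this thread) is that the pickled file holds no more samples than val_num and nothing is left for training. The sketch below shows a pre-flight check under that assumption. The helper names split_dataset and check_split are hypothetical and are not taken from train.py, and the split order (validation samples first) is also assumed.

```python
def split_dataset(samples, val_num):
    """Split samples into (train, validation).

    Assumption: the first val_num samples become the validation set and the
    rest become the training set. This is an illustrative guess, not code
    copied from train.py.
    """
    validation = samples[:val_num]
    train = samples[val_num:]
    return train, validation


def check_split(samples, val_num):
    """Fail early with a readable message instead of num_samples=0."""
    train, validation = split_dataset(samples, val_num)
    if len(train) == 0:
        raise ValueError(
            "training split is empty: "
            f"{len(samples)} samples loaded, val_num={val_num}; "
            "use a smaller val_num or a larger dataset"
        )
    return train, validation


if __name__ == "__main__":
    # Toy data standing in for the tokenized dialogues in the .pkl file.
    toy_samples = [[101, 1, 2, 102] for _ in range(10)]
    train, validation = check_split(toy_samples, val_num=3)
    print(len(train), len(validation))
```

If a check like this fails on the real data, printing the length of the list loaded from the .pkl file and comparing it with val_num should show whether the split is the culprit. Lowering val_num (presumably via a --val_num option, given the logged arguments) or regenerating a larger dataset would then be the things to try.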