Open KangChou opened 2 years ago
train.py 348行改成 input_ids = input_ids.to(device)
解决了吗?
train.py 348行改成 input_ids = input_ids.to(device)
348行是“def calculate_acc(logit, labels, ignore_index=-100):”,这可怎么改..
我也遇到这个问题了,好奇怪啊
我也遇到这个问题了,好奇怪啊
见:https://github.com/yangjianxin1/GPT2-chitchat/issues/79#issuecomment-941684250 中 r1cebank 的回答
python train.py --epochs 1 --batch_size 2 --device 0 --train_path data/train_2w.pkl