DLLXW / baby-llama2-chinese

用于从头预训练+SFT一个小参数量的中文LLaMa2的仓库;24G单卡即可运行得到一个具备简单中文问答能力的chat-llama2.
MIT License
2.42k stars 296 forks source link

dataset_sft.py中loss_mask的切片为什么和X一致? #13

Open BigaGrayWolf opened 1 year ago

BigaGrayWolf commented 1 year ago

dataset_sft.py中的第50行 loss_mask=np.array(loss_mask[:-1]) 个人觉得应该改为 loss_mask=np.array(loss_mask[1:]) 和Y的切片一致,因为最后算Loss的时候,是将Y的值和模型的输出进行比较。

DLLXW commented 1 year ago

dataset_sft.py中的第50行 loss_mask=np.array(loss_mask[:-1]) 个人觉得应该改为 loss_mask=np.array(loss_mask[1:]) 和Y的切片一致,因为最后算Loss的时候,是将Y的值和模型的输出进行比较。

这个问题我在写的时候也有思考过,但是根据我写的那个逻辑,好像没啥问题,可以具体打出来看下是不是符合预期。

Niculuse commented 12 months ago

个人认为应该和X的切片一致。Y是标签,是X的标签,计算loss的时候,mask应该指明输入中哪些位置不参与计算,所以是和X切片一致。