THUDM / SwissArmyTransformer

SwissArmyTransformer is a flexible and powerful library to develop your own Transformer variants.
https://THUDM.github.io/SwissArmyTransformer
Apache License 2.0

A question: how should the loss be written when using mp_size=2? #131

Open kunden0612 opened 1 year ago

kunden0612 commented 1 year ago
from torch.nn import CrossEntropyLoss

# Forward pass, then flatten logits/labels for token-level cross-entropy.
logits, *mems = model(inputs_ids, position_ids, attention_mask)
# print(logits.shape)
loss_func = CrossEntropyLoss(ignore_index=-100)
loss = loss_func(logits.view(-1, logits.size(-1)).float(), labels.view(-1))

This is how I compute the loss, but I get the following error: /opt/conda/conda-bld/pytorch_1670525539683/work/aten/src/ATen/native/cuda/Loss.cu:242: nll_loss_forward_reduce_cuda_kernel_2d: block: [0,0,0], thread: [15,0,0] Assertion `t >= 0 && t < n_classes` failed.

1049451037 commented 1 year ago

Did you pass parallel_output=True in the forward call? https://github.com/THUDM/SwissArmyTransformer/blob/main/sat/transformer_defaults.py#L146

If so, the output has not been gathered yet and is still split across multiple model-parallel ranks, so each rank only holds a slice of the vocabulary and labels can exceed n_classes.
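A minimal sketch of the fix under that assumption (not verified end to end; it assumes the forward call simply forwards parallel_output to the final layer as in transformer_defaults.py): with parallel_output=False the logits are gathered to the full vocabulary size on every rank, so the plain CrossEntropyLoss in the question works unchanged.

```python
import torch
from torch.nn import CrossEntropyLoss

# Sketch only: parallel_output=False asks the final layer to gather the
# vocab-parallel logits, so logits.size(-1) equals the full vocabulary size
# and every label value stays within [0, vocab_size).
logits, *mems = model(inputs_ids, position_ids, attention_mask,
                      parallel_output=False)

loss_func = CrossEntropyLoss(ignore_index=-100)
loss = loss_func(logits.view(-1, logits.size(-1)).float(), labels.view(-1))
```

If you want to keep parallel_output=True (to save memory and communication), you would instead need a vocab-parallel cross-entropy that understands the split logits, as in Megatron-style training; otherwise the labels can index past the per-rank vocabulary slice and trigger exactly the `t >= 0 && t < n_classes` assertion above.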