GanjinZero / RRHF

[NIPS2023] RRHF & Wombat

Runtime error: data type error #55

Open · sqqiao opened this issue 1 month ago

sqqiao commented 1 month ago

Hi, I ran into a data type error while reproducing RRHF. I set up an fsdp_config for distributed training, and when I enable --bf16 mixed precision I get:

return torch.embedding(weight, input, padding_idx, scale_grad_by_freq, sparse)
RuntimeError: Expected tensor for argument #1 'indices' to have one of the following scalar types: Long, Int; but got CUDABFloat16Type instead (while checking arguments for embedding)

If I run without bf16 and tf32, I get:

return torch.embedding(weight, input, padding_idx, scale_grad_by_freq, sparse)
RuntimeError: Expected tensor for argument #1 'indices' to have one of the following scalar types: Long, Int; but got torch.cuda.FloatTensor instead (while checking arguments for embedding)
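Both tracebacks point at the same thing: the embedding layer is receiving token ids (`indices`) with a floating-point dtype, bfloat16 in one run and float32 in the other. One reading (an inference, not confirmed in this thread) is that the ids are already float before training starts and are then cast along with everything else when --bf16 is on. The sketch below reproduces the error on CPU with plain PyTorch and adds a hypothetical guard, `check_input_ids`, that could be called right before the model's forward pass to fail early with a clearer message. It intentionally does not "repair" float ids, because bfloat16 cannot represent many ids above 256 exactly.

```python
import torch
import torch.nn as nn


def embedding_accepts(ids: torch.Tensor) -> bool:
    """Return True if nn.Embedding accepts `ids`, False if it raises the indices-dtype RuntimeError."""
    emb = nn.Embedding(num_embeddings=2048, embedding_dim=4)
    try:
        emb(ids)
        return True
    except RuntimeError:
        return False


def check_input_ids(input_ids: torch.Tensor) -> torch.Tensor:
    """Hypothetical guard to call right before model(...).

    It does NOT cast float ids back to int64: bfloat16 keeps only 8 significant
    bits, so many ids above 256 have already been rounded to a different token.
    """
    if input_ids.dtype not in (torch.int64, torch.int32):
        raise TypeError(f"input_ids must be int64/int32, got {input_ids.dtype}")
    return input_ids


if __name__ == "__main__":
    ids = torch.tensor([[1, 2, 1001]])
    # Only the integer dtype is accepted; float32 and bfloat16 raise inside torch.embedding
    for dtype in (torch.int64, torch.float32, torch.bfloat16):
        print(dtype, embedding_accepts(ids.to(dtype)))
    # 1001 is not exactly representable in bfloat16, so the round trip changes the id
    print(ids.to(torch.bfloat16).long().tolist())
```

If the guard fires, the place to fix is wherever the ids were first created as floats, not the forward pass.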

My fsdp_config settings are shown in image 1.

The model I'm using is llama3-8b. Could it be that the tokenizer needs to be reconfigured?
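On the tokenizer question: Hugging Face tokenizers called with `return_tensors="pt"` return `input_ids` as int64, so the tokenizer itself is an unlikely source of float ids. A more likely suspect (an assumption, since the data pipeline is not shown here) is a later step that rebuilds the tensors, for example with `torch.Tensor(...)`, which always produces the default float dtype, whereas `torch.tensor(...)` infers int64 from Python ints. The hypothetical `collate_input_ids` below builds and pads ids with an explicit `dtype=torch.long`; the `pad_id` value is also an assumption, since Llama 3 does not define a pad token by default.

```python
import torch
from torch.nn.utils.rnn import pad_sequence


def collate_input_ids(token_lists, pad_id):
    """Hypothetical collate step: build int64 id tensors explicitly and right-pad them."""
    seqs = [torch.tensor(ids, dtype=torch.long) for ids in token_lists]
    input_ids = pad_sequence(seqs, batch_first=True, padding_value=pad_id)
    # Build the mask from sequence lengths rather than comparing against pad_id,
    # so a real token that happens to equal pad_id is not masked out
    attention_mask = pad_sequence(
        [torch.ones_like(s) for s in seqs], batch_first=True, padding_value=0
    )
    return input_ids, attention_mask


if __name__ == "__main__":
    # torch.Tensor(...) yields the default float dtype; torch.tensor(...) keeps int64
    print(torch.Tensor([1, 2, 3]).dtype, torch.tensor([1, 2, 3]).dtype)
    ids, mask = collate_input_ids([[5, 6, 7], [8, 9]], pad_id=0)
    print(ids.dtype, ids.tolist(), mask.tolist())
```

Printing `batch["input_ids"].dtype` right after the collate function is a quick way to see whether the ids are already float before FSDP or bf16 is involved.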