Closed qz701731tby closed 1 year ago
sparse_multilabel_categorical_crossentropy这个函数计算出来的loss没有除以batch_size,训练的代码中也没有处理。
https://github.com/xhw205/GPLinker_torch/blob/d065bd36eb0b26676db4be17fe88203796eef12d/nets/gpNet.py#L31 或许torch.sum()应该添加一个dim=1的参数设置, torch.sum(pos_loss + neg_loss, dim=1) 。
收到,谢谢~
sparse_multilabel_categorical_crossentropy这个函数计算出来的loss没有除以batch_size,训练的代码中也没有处理。