wujcan / SGL-TensorFlow

173 stars 42 forks source link

关于计算batch loss的问题 #17

Closed hhmy27 closed 2 years ago

hhmy27 commented 2 years ago

https://github.com/wujcan/SGL/blob/fcfb31f3af9a41fc2f6bc81932733952b91bbc90/model/general_recommender/SGL.py#L454-L460

对于每个batch累计的各类loss,为什么打印的时候,除以的是 data_iter.num_trainings (数据集的大小)呢?

我看LightGCN里面除以的是dataset的batch的个数

hotchilipowder commented 2 years ago

ssl_loss这里的实现是sum,bpr应该是mean。

wujcan commented 2 years ago

https://github.com/wujcan/SGL/blob/fcfb31f3af9a41fc2f6bc81932733952b91bbc90/model/general_recommender/SGL.py#L454-L460

对于每个batch累计的各类loss,为什么打印的时候,除以的是 data_iter.num_trainings (数据集的大小)呢?

我看LightGCN里面除以的是dataset的batch的个数

我们这里计算的是每个instance的平均loss,这样就跟batch size无关了。事实上,这也只是影响loss的相对大小。