shamangary / Keras-MNIST-center-loss-with-visualization

An implementation for mnist center loss training and visualization
75 stars 18 forks source link

中心损失中中心值未更新 #5

Closed longzeyilang closed 6 years ago

longzeyilang commented 6 years ago

按照上诉的方法,中心值没有更新,每次还是取随机值,更新需要进行梯度更新或者每次重新计算中心值

shamangary commented 6 years ago

如果中心值每次都取隨機值,後續的聚類點應該會亂跳動,然而後續的聚類點是逐漸穩定,請問你這個講法是有什麼根據嗎?

longzeyilang commented 6 years ago

@shamangary 你好,合理的做法,在每次batch中,将每次的计算中心进行更新,可以参考下面链接 https://github.com/Kakoedlinnoeslovo/center_loss/blob/master/Network.py CenterLayer

shamangary commented 6 years ago

@longzeyilang 你好,我大概了解你指的應該是下面這行https://github.com/Kakoedlinnoeslovo/center_loss/blob/master/Network.py#L61

我簡述我對keras和center loss的理解如下,在keras中加入自定義的weight可如下簡單範例 https://keras.io/layers/writing-your-own-keras-layers/ 該weight是會自動update的,而我這參考https://kexue.fm/archives/4493 所寫的center loss是採用了 Embedding layer: https://github.com/keras-team/keras/blob/master/keras/layers/embeddings.py 其weight即對應於center,而查找的label則返還對應的weight。

而在原始center loss paper的Eq.(4)他所用的delta_C_j其實就是你貼的連結所實作的, 我覺得你貼的連結的確是比較重現原作,但這似乎不是因為Embedding的weight沒有update, 而是因為該連結做了一個額外的update:https://github.com/Kakoedlinnoeslovo/center_loss/blob/master/Network.py#L57 我看起來像是外加的,因為原作也提過這概念不是直接微分的, 所以用Embedding做的比較像是簡易版本(用查找直接用微分值update),但是沒有完全體現原作這樣。

longzeyilang commented 6 years ago

@shamangary 你好,用Embedding layer只是计算每个batch的重复中心,没有将之前的batch连接起来,所以用Embedding layer不太合适,我在原来基础进行修改,结果如下: def call(self,x,mask=None): #x=[features,label] label = tf.reshape(x[1], [-1]) label=tf.to_int32(label, name='ToInt32') centers_batch = tf.gather(self.centers, label) diff = (1 - self.alpha_center) * (centers_batch - x[0]) new_centers = tf.scatter_sub(self.centers, label, diff) self.add_update((self.centers,new_centers),x) result=tf.reduce_mean(tf.square(x[0] - centers_batch)) return result

shamangary commented 6 years ago

@longzeyilang 我其實是覺得如果你也同意Embedding的weight有update,某種程度上是跟之前的batch有關連的,不過我也同意你的說法是如果增加原作者那條直接讓center和delta_center的相加會有更好的效果,感謝你提供的資訊,學習了。

BTW,我對於add_update的理解如下 在官方的example有weight卻沒有寫add_update,那到底有沒有update, https://keras.io/layers/writing-your-own-keras-layers/ 或是Embedding layer沒有寫add_update那是沒有update weight嗎? 由於寫完單層後我們都會去call整個model,而keras model有包含network, https://github.com/keras-team/keras/blob/master/keras/engine/network.py#L765 所以在這就有add_update了,只是這裡的add_update就是很制式的一般layer update, 如果更特殊可能就要自己寫這樣,所以一般的weight想要update應該是不太需要寫add_update的, 我的理解是這樣,有錯請指正。