Prinsphield / ELEGANT

ELEGANT: Exchanging Latent Encodings with GAN for Transferring Multiple Face Attributes
https://arxiv.org/abs/1803.10562
MIT License

A small question #24

Open tom99763 opened 2 years ago

tom99763 commented 2 years ago

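You concatenate the original z so that the encoder still gets updated, right? After all, the element-swap operation is not differentiable. When you were building this model, did you try a straight-through estimator?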
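A straight-through estimator uses the result of a gradient-blocking operation in the forward pass, but in the backward pass it copies the upstream gradient through as if that operation were the identity. The minimal NumPy sketch below writes both passes out by hand to make that concrete. The function names `swap_forward` and `ste_backward` are hypothetical illustrations, not code from the ELEGANT repository.

```python
import numpy as np


def swap_forward(z1, z2, channels):
    """Forward pass: return a copy of z1 whose selected channels are taken from z2."""
    z12 = z1.copy()
    z12[..., channels] = z2[..., channels]
    return z12


def ste_backward(upstream_grad):
    """Straight-through backward pass: treat swap_forward as the identity in z1.

    This mirrors z12 = z1 + stop_gradient(z12 - z1): d(z12)/d(z1) is taken
    to be 1 everywhere, and no gradient at all is sent to z2.
    """
    grad_z1 = upstream_grad.copy()
    grad_z2 = np.zeros_like(upstream_grad)
    return grad_z1, grad_z2


z1 = np.zeros((1, 2, 2, 3))
z2 = np.ones((1, 2, 2, 3))
z12 = swap_forward(z1, z2, channels=[1])
g1, g2 = ste_backward(np.full(z12.shape, 0.5))
print(z12[0, 0, 0])              # [0. 1. 0.]
print(g1[0, 0, 0], g2[0, 0, 0])  # [0.5 0.5 0.5] [0. 0. 0.]
```

The forward value contains z2's channel, yet the gradient flows only to z1. In other words, with this estimator the encoder output that produced z2 receives no learning signal through the swapped channels.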

Thanks!

tom99763 commented 2 years ago

By the way, I'd like to share that I've been trying to run experiments with this model in TensorFlow.

This is how I implemented the swap and the straight-through estimator I just mentioned:

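import tensorflow as tf

def get_idx(z, y):
    # For z of shape (B, H, W, C), build indices (b, h, w, y[b]) for every
    # spatial position, i.e. select channel y[b] of sample b everywhere.
    b, h, w = z.shape[0], z.shape[1], z.shape[2]
    db, dh, dw = tf.meshgrid(tf.range(b), tf.range(h), tf.range(w), indexing="ij")
    idx = tf.stack([db, dh, dw], axis=-1)                                      # (B, H, W, 3)
    y = tf.cast(y, tf.int32)
    y = tf.repeat(tf.repeat(y[:, None, None], h, axis=1), w, axis=2)[..., None]  # (B, H, W, 1)
    return tf.concat([idx, y], axis=-1)                                        # (B, H, W, 4)

def get_corr_ele(z, y1, y2):
    # Gather the channels selected by y1 and by y2 for every sample.
    idx1 = get_idx(z, y1)
    idx2 = get_idx(z, y2)
    idx = tf.concat([idx1, idx2], axis=0)   # (2B, H, W, 4)
    ele = tf.gather_nd(z, idx)              # (2B, H, W)
    return ele, idx

def swap(z1, z2, y1, y2):
    # Replace channels y1 and y2 of z1 with the same channels of z2.
    z1y, idx = get_corr_ele(z1, y1, y2)
    z2y, _ = get_corr_ele(z2, y1, y2)
    z12 = tf.tensor_scatter_nd_update(z1, idx, z2y)
    # z21 = tf.tensor_scatter_nd_update(z2, idx, z1y)

    # straight-through estimator: the forward pass outputs z12, while the
    # backward pass sends the gradient straight to z1 and nothing to z2
    z12 = z1 + tf.stop_gradient(z12 - z1)
    # z21 = z2 + tf.stop_gradient(z21 - z2)
    z12 = tf.concat([z12, z1], axis=-1)     # (B, H, W, 2C)
    return z12  # , z21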
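To check the index bookkeeping in `get_idx` / `swap` without TensorFlow, here is a NumPy sketch that should compute the same forward value with advanced indexing. It assumes that `z1` and `z2` have shape (B, H, W, C) and that `y1` and `y2` are integer arrays of shape (B,) holding one channel index per sample. `swap_reference` is a hypothetical name, not part of the repository. Its output can be compared with `np.allclose` against the TensorFlow `swap` output on the same inputs. Use `allclose` rather than exact equality because `z1 + (z12 - z1)` can differ from `z12` by rounding.

```python
import numpy as np


def swap_reference(z1, z2, y1, y2):
    """NumPy forward-pass equivalent of swap(z1, z2, y1, y2).

    For every sample b, channels y1[b] and y2[b] of z1 are replaced by the
    same channels of z2; the untouched z1 is then concatenated on the
    channel axis, giving shape (B, H, W, 2C).
    """
    b = np.arange(z1.shape[0])
    z12 = z1.copy()
    for y in (y1, y2):
        # The advanced indices b and y are paired element-wise (sample b gets
        # channel y[b]); the two slices keep the full H and W extent.
        z12[b, :, :, y] = z2[b, :, :, y]
    return np.concatenate([z12, z1], axis=-1)


rng = np.random.default_rng(0)
z1 = rng.normal(size=(2, 4, 4, 8))
z2 = rng.normal(size=(2, 4, 4, 8))
y1 = np.array([0, 3])
y2 = np.array([5, 7])
out = swap_reference(z1, z2, y1, y2)
print(out.shape)  # (2, 4, 4, 16)
```

Because the channel index is chosen per sample, channel 3 of sample 1 comes from `z2` while channel 3 of sample 0 stays from `z1`. That per-sample pairing is what the `(b, h, w, y[b])` indices in `get_idx` encode.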