nairouz / DynAE


Critic weights are not modified during pre-training #16

Closed m-schier closed 3 years ago

m-schier commented 3 years ago

Hi,

it appears to me that there is an issue in the published version of the code that leads to the critic/discriminator weights never being trained. In the function aci_ae_constructor in DynAE.py, the critic's weights are set as non-trainable (critic.trainable = False) because the critic should not be updated while the autoencoder is trained. However, this change also propagates to the compiled critic/discriminator model, so DynAE.train_on_batch_disc() does not update any weights.
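
The mechanism is easy to reproduce outside the repository with a toy model (a minimal sketch with made-up layer sizes, not DynAE's actual architecture): any model compiled after the critic has been frozen ends up with no trainable weights at all.

    import numpy as np
    from tensorflow import keras

    # Toy critic standing in for DynAE's critic network
    critic = keras.Sequential([keras.layers.Dense(8, activation="relu", input_shape=(4,)),
                               keras.layers.Dense(1)])

    critic.trainable = False                    # meant to freeze the critic for the autoencoder phase only
    x = keras.Input(shape=(4,))
    disc = keras.Model(x, critic(x))            # discriminator model wrapping the same critic
    disc.compile(optimizer="adam", loss="mse")  # compiled after the flag was flipped

    print("disc trainable weights:", len(disc.trainable_weights))  # 0

    before = [w.copy() for w in disc.get_weights()]
    disc.train_on_batch(np.random.rand(16, 4), np.zeros((16, 1)))  # runs, but is effectively a no-op
    after = disc.get_weights()
    print("weights changed:", any(not np.array_equal(b, a) for b, a in zip(before, after)))  # False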

This can be verified in two ways: either by checking whether the critic weights change during training, by inserting the following after line 333 of DynAE.py:

        last_critic_weights = None

        #Training loop
        for ite in range(int(maxiter)):
            #Validation interval
            if ite % validate_interval == 0:
                # Check weights modified
                critic_weights = self.disc.get_weights()
                if last_critic_weights is not None:
                    if all([np.all(critic_weights[i] == last_critic_weights[i]) for i in range(len(critic_weights))]):
                        print("Critic weights were not modified during training")
                last_critic_weights = [np.copy(c) for c in critic_weights]

or by directly checking whether the discriminator model has any trainable weights at all, i.e. by modifying:

    def train_on_batch_disc(self, x1, x2, y1, y2):
        y = np.zeros((x1.shape[0],))
        if len(self.disc.trainable_weights) == 0:
            print("The discriminator is being trained with no trainable weights")
        return self.disc.train_on_batch([x1, x2, y1, y2], y)

It is surprising to me how well the model performs despite the discriminator not being trained, but it seems to me that your published work does things differently. Is this a mistake?

Kind regards

nairouz commented 3 years ago

Hi,

Thank you for your interest. I think my implementation was based on this recommendation: https://stackoverflow.com/questions/53700833/keras-trainable-scope
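
In short, the recommendation is to compile the critic on its own while it is still trainable, and only then freeze it inside the combined model, since Keras captures the trainable values at compile time. A rough, generic sketch of that pattern (placeholder layers, not my actual code):

    from tensorflow import keras

    # Toy critic and generator; placeholder sizes, not DynAE's architecture
    critic = keras.Sequential([keras.layers.Dense(8, activation="relu", input_shape=(4,)),
                               keras.layers.Dense(1)])
    critic.compile(optimizer="adam", loss="mse")        # compiled while trainable, so it keeps learning
    print("critic trainable weights:", len(critic.trainable_weights))       # > 0

    generator = keras.Sequential([keras.layers.Dense(4, input_shape=(2,))])
    critic.trainable = False                            # only affects models compiled from here on
    z = keras.Input(shape=(2,))
    combined = keras.Model(z, critic(generator(z)))
    combined.compile(optimizer="adam", loss="mse")      # trains the generator, critic frozen inside

    print("combined trainable weights:", len(combined.trainable_weights))   # only the generator's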

I remember that I was able to get 96% on MNIST when optimizing the adversarially constrained interpolation objective without data augmentation. However, I am no longer able to reproduce similar results. I suspect that recent versions of TensorFlow are not compatible with my code, or that I made some mistakes when cleaning up the code.

The fact that the model performs very well despite the discriminator not being trained is probably due to data augmentation. In my experience, pretraining based on ACI does not provide any improvement when data augmentation is available.

For now, you can ignore this issue, as it does not have any impact on the clustering phase. I will set aside time to fix it in the future.

m-schier commented 3 years ago

Thank you very much for your quick response. The Stack Overflow question you linked (your own) was quite helpful: indeed, when moving the call to compile_disc before the call to aci_ae_constructor, the critic's weights are updated during training. However, the resulting accuracy is worse, so something else is still not quite right. Anyway, since you state that the good pretraining accuracy is due to the augmentation used, I will look into that, as I was primarily interested in the results of the pretraining phase.