theislab / trVAE

Conditional out-of-distribution prediction
MIT License

Question about training #18

Open xiaoyanLi629 opened 3 years ago

xiaoyanLi629 commented 3 years ago

Hello,

I have a question about how the model is trained. I saw two losses in the compile_models function in the _trvae.py file, and two outputs in the model, reconstruction_output and mmd_output. It looks to me like these two outputs are computed in parallel.

Could you explain a little how the training works? Are these losses minimized alternately? Are the two losses coupled, that is, does minimizing one loss affect the other?

    decoder_inputs = [encoder_outputs, self.decoder_labels]

    # Both decoder heads receive the same latent code and condition labels.
    decoder_outputs = self.decoder_model(decoder_inputs)
    decoder_mmd_outputs = self.decoder_mmd_model(decoder_inputs)

    # Identity Lambda layers only name the two outputs so each can get its own loss.
    reconstruction_output = Lambda(lambda x: x, name="reconstruction")(decoder_outputs)
    mmd_output = Lambda(lambda x: x, name="mmd")(decoder_mmd_outputs)

    # One loss per output, in the same order as the model's outputs.
    self.cvae_model.compile(optimizer=optimizer,
                            loss=[loss, mmd_loss],
                            metrics={self.cvae_model.outputs[0].name: loss,
                                     self.cvae_model.outputs[1].name: mmd_loss})

Thank you!
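For context on the "alternated or coupled?" question: when a multi-output Keras model is compiled with a list of losses, Keras minimizes the sum of the per-output losses (weighted if `loss_weights` is passed) with one gradient step per batch, so the two losses are optimized jointly, not alternately. Because both outputs are computed from the same `encoder_outputs` (and, presumably, from layers shared between `decoder_model` and `decoder_mmd_model`), each loss's gradient moves shared weights and so affects the other loss. The toy sketch below illustrates only that mechanism; it is not trVAE's model, and the single shared parameter `theta` and the two quadratic "losses" are hypothetical stand-ins.

```python
# Hypothetical toy model: one shared parameter `theta` that both losses depend on,
# standing in for the weights shared by the two Keras outputs.
def recon_loss(theta):
    return (theta - 3.0) ** 2  # pretend reconstruction is best at theta = 3


def mmd_loss(theta):
    return (theta - 1.0) ** 2  # pretend the MMD penalty is best at theta = 1


def grad_total(theta, mmd_weight):
    # Gradient of recon_loss + mmd_weight * mmd_loss: one update uses both terms.
    return 2.0 * (theta - 3.0) + mmd_weight * 2.0 * (theta - 1.0)


def train_jointly(mmd_weight=1.0, lr=0.1, steps=200):
    # Plain gradient descent on the summed objective (no alternation).
    theta = 0.0
    for _ in range(steps):
        theta -= lr * grad_total(theta, mmd_weight)
    return theta


theta = train_jointly(mmd_weight=1.0)
print(round(theta, 4), round(recon_loss(theta), 4), round(mmd_loss(theta), 4))  # 2.0 1.0 1.0

# Coupling: a step on the reconstruction loss alone lowers it but raises the MMD loss.
theta_r = theta - 0.1 * 2.0 * (theta - 3.0)
print(round(recon_loss(theta_r), 4), round(mmd_loss(theta_r), 4))  # 0.64 1.44
```

Joint training settles on a compromise between the two optima, and changing the MMD weight shifts that compromise (here toward `theta = 1` as `mmd_weight` grows).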

bhomass commented 1 year ago

I am a user just like you, but I can answer your question: it is all explained in the referenced paper. There is the usual VAE loss, made up of a reconstruction term and a KL-divergence (KLD) term between the variational distribution over z and the prior on z. Then there is the MMD loss, which encourages the hidden representations of samples with different conditions to follow matching distributions, so that the condition effect is removed from that hidden layer.
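The three terms mentioned above can be sketched in NumPy as follows. This is a minimal illustration, assuming a diagonal-Gaussian encoder (`mu`, `log_var`), a squared-error reconstruction term, and a single RBF kernel for the MMD. The function names, the kernel `bandwidth`, and the weights `kl_weight` and `mmd_weight` are hypothetical and not taken from trVAE's code, whose `mmd_loss` may use different kernels and weighting. Minimizing the MMD between the condition groups of a hidden representation `h` pulls those groups' distributions together.

```python
import numpy as np


def kl_divergence(mu, log_var):
    """KL(N(mu, diag(exp(log_var))) || N(0, I)), summed over latent dims, averaged over the batch."""
    return np.mean(-0.5 * np.sum(1.0 + log_var - mu**2 - np.exp(log_var), axis=1))


def rbf_kernel(a, b, bandwidth=1.0):
    # Squared Euclidean distance between every row of a and every row of b.
    sq_dist = (np.sum(a**2, axis=1)[:, None]
               + np.sum(b**2, axis=1)[None, :]
               - 2.0 * a @ b.T)
    return np.exp(-sq_dist / (2.0 * bandwidth**2))


def mmd_squared(x, y, bandwidth=1.0):
    """Biased estimate of the squared MMD between samples x and y."""
    return (rbf_kernel(x, x, bandwidth).mean()
            + rbf_kernel(y, y, bandwidth).mean()
            - 2.0 * rbf_kernel(x, y, bandwidth).mean())


def total_loss(x, x_hat, mu, log_var, h, conditions, kl_weight=1.0, mmd_weight=1.0):
    recon = np.mean(np.sum((x - x_hat) ** 2, axis=1))  # squared-error reconstruction
    kl = kl_divergence(mu, log_var)
    # MMD between the hidden representations h of every pair of condition groups.
    groups = np.unique(conditions)
    mmd = sum(mmd_squared(h[conditions == a], h[conditions == b])
              for i, a in enumerate(groups) for b in groups[i + 1:])
    return recon + kl_weight * kl + mmd_weight * mmd


rng = np.random.default_rng(0)
a = rng.normal(size=(200, 2))
same = rng.normal(size=(200, 2))
# Samples from the same distribution give a much smaller MMD than shifted ones.
print(mmd_squared(a, same) < mmd_squared(a, a + 3.0))  # True
```

All terms are summed into one scalar, which matches how a multi-output model compiled with a list of losses is trained on a single combined objective.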