zurutech / anomaly-toolbox

Anomaly detection using GANs.
MIT License

generator_bce input parameter #8

Closed: fer927 closed this issue 4 months ago

fer927 commented 4 months ago

Description

Isn't the generator loss supposed to be computed from the discriminator's output on the reconstructed image? If so, shouldn't the input to bce_g_loss be d_gex instead of g_ex? Below is the relevant code from ganomaly.py (in trainers):

# Discriminator on the reconstructed real data g_ex
d_gex, _ = self.discriminator(inputs=g_ex, training=True)

# Encode the reconstructed real data g_ex
e_gex = self.encoder(g_ex, training=True)

# Discriminator Loss
d_loss = self._minmax(d_x_features, d_gex_features)
d_loss = self._minmax(d_x, d_gex)

# Generator Loss
adversarial_loss = losses.adversarial_loss_fm(d_f_x, d_f_x_hat)
bce_g_loss = generator_bce(g_ex, from_logits=True)

iLeW commented 4 months ago

Hello @fer927,

thank you for your detailed observation and for bringing this to our attention.

Issue Summary

You pointed out that in the current implementation of ganomaly.py, the generator's Binary Cross-Entropy (BCE) loss is computed as follows:

bce_g_loss = generator_bce(g_ex, from_logits=True)

You suggested that it should instead be:

bce_g_loss = generator_bce(d_gex, from_logits=True)

Evaluation of the Issue

After reviewing the code and the original GANomaly implementation, we believe your observation is indeed correct. The generator loss should incorporate the discriminator's feedback on the reconstructed images, so that the generator learns to produce images the discriminator accepts as realistic.

Details

In GAN training, the generator's adversarial loss typically uses the discriminator's output on the generated data. This helps guide the generator to create data that can fool the discriminator.
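A minimal, framework-free sketch of this idea in plain Python follows (the toolbox itself uses Keras-style code, as in the excerpt above). It assumes the discriminator returns raw logits, as from_logits=True suggests, and uses the common non-saturating form: BCE against a target label of 1 ("real") on the discriminator's logits for generated samples. generator_bce_sketch is a hypothetical name for illustration, not the toolbox's generator_bce.

```python
import math
from typing import Sequence


def softplus(x: float) -> float:
    """Numerically stable log(1 + exp(x))."""
    return max(x, 0.0) + math.log1p(math.exp(-abs(x)))


def generator_bce_sketch(d_fake_logits: Sequence[float]) -> float:
    """Hypothetical generator BCE loss computed on discriminator logits.

    Each logit z is the discriminator's score for a *generated* sample.
    With target label 1, BCE(1, sigmoid(z)) = -log(sigmoid(z)) = softplus(-z),
    which avoids overflow for large |z|. Returns the mean over the batch.
    """
    if len(d_fake_logits) == 0:
        raise ValueError("expected at least one logit")
    return sum(softplus(-z) for z in d_fake_logits) / len(d_fake_logits)


# The loss is small when the discriminator is fooled (large positive logits)
# and large when it confidently rejects the generated samples.
fooled = generator_bce_sketch([4.0, 5.0])
rejected = generator_bce_sketch([-4.0, -5.0])
print(f"fooled={fooled:.4f} rejected={rejected:.4f}")
```

The key point for this issue is the input: the function expects the discriminator's scores for the generated data, not the generated images themselves.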

In the original GANomaly implementation (model.py), the adversarial loss for the generator is computed based on the discriminator’s feedback, as shown below:

self.err_g_adv = self.l_adv(self.netd(self.input)[1], self.netd(self.fake)[1])

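This shows that the generator's adversarial loss (err_g_adv) is computed from the discriminator's output: it compares the discriminator's features on the real input (self.netd(self.input)[1]) with its features on the generated image (self.netd(self.fake)[1]).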
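The feature-matching term can be sketched in the same plain-Python style. This assumes l_adv is an L2 (mean squared error) loss over the discriminator's intermediate features, which is our reading of the original code; feature_matching_l2 is a hypothetical name, and real features are multi-dimensional tensors rather than flat lists.

```python
from typing import Sequence


def feature_matching_l2(
    real_features: Sequence[float], fake_features: Sequence[float]
) -> float:
    """Mean squared difference between discriminator features (hypothetical).

    real_features: features of the real input (cf. self.netd(self.input)[1])
    fake_features: features of the generated image (cf. self.netd(self.fake)[1])
    """
    if len(real_features) != len(fake_features):
        raise ValueError("feature vectors must have the same length")
    if len(real_features) == 0:
        raise ValueError("expected non-empty feature vectors")
    squared = [(r - f) ** 2 for r, f in zip(real_features, fake_features)]
    return sum(squared) / len(squared)


# Identical features give zero loss; the loss grows as the generated
# image's features drift away from those of the real input.
print(feature_matching_l2([0.5, 1.0, -0.2], [0.5, 1.0, -0.2]))  # prints 0.0
print(feature_matching_l2([0.0, 0.0], [1.0, 3.0]))  # prints 5.0
```

Judging by its name and arguments, losses.adversarial_loss_fm(d_f_x, d_f_x_hat) in the toolbox appears to play this feature-matching role, while the question raised here concerns the separate BCE term.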

Considerations

During the development of the anomaly toolbox, we encountered various challenges that might have led to the current implementation. Typically, when we deviate from standard practices, we leave comments explaining our rationale. However, in this case, there is no such comment, which suggests that your proposed change is likely correct.

Proposed Change

To align with standard GAN training practice and the original GANomaly model, the generator's BCE loss in ganomaly.py should be computed on d_gex, the discriminator's output on the reconstructed data:

bce_g_loss = generator_bce(d_gex, from_logits=True)

For further reference, please see the original GANomaly paper, "GANomaly: Semi-Supervised Anomaly Detection via Adversarial Training" (Akçay et al., ACCV 2018).

Conclusion

Your suggestion appears to be correct, and we appreciate your attention to detail. We will update the code to reflect this correction. Thank you for helping us improve the quality of our project!