sintel-dev / Orion

Library for detecting anomalies in signals
https://sintel.dev/Orion/
MIT License

How to interpret (negative) eg_loss values in TadGAN training #441

Closed nunobv closed 1 year ago

nunobv commented 1 year ago

Just reviving the question/issue in #165 to address some doubts of mine:

sarahmish commented 1 year ago

Hi @nunobv, thank you for opening the issue

For a better understanding of the loss, I recommend setting detailed_losses=True. You can do that by adding it to the hyperparameter settings as follows:

hyperparameters = {
    "orion.primitives.tadgan.TadGAN#1": {
        "detailed_losses": True
    }
}
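For reference, this dictionary is then passed when building the pipeline. A minimal sketch, assuming the standard Orion API (the fit call is left commented out since it requires a signal and a full installation; train_data is a placeholder name):

```python
# Hyperparameters are keyed by the primitive's full name in the pipeline.
hyperparameters = {
    "orion.primitives.tadgan.TadGAN#1": {
        "detailed_losses": True,
    }
}

# from orion import Orion
# orion = Orion(pipeline="tadgan", hyperparameters=hyperparameters)
# orion.fit(train_data)  # per-epoch loss dictionaries are printed during fitting
```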

This will present you with 4 values for each network in the model:

  1. for the cx critic we have cx_loss, cx_real, cx_fake, cx_gp respectively.
    • cx_loss is the total loss value
    • cx_real is the critic score of a real sample, cx(x)
    • cx_fake is the critic score of a fake generated sample, cx(g(z))
    • cx_gp is the gradient penalty regularization term
  2. for the cz critic we have the same set of values but for z, the latent variable.
  3. for the eg encoder-generator we have eg_loss, eg_cx_fake, eg_cz_fake, eg_mse respectively.
    • eg_loss is the total loss value
    • eg_cx_fake is the critic score for a fake generated x, meaning g(z)
    • eg_cz_fake is the critic score for an encoded x, meaning e(x)
    • eg_mse is the mean squared error between a real x and a reconstructed x, meaning x and g(e(x))

Now regarding your points: first, note that critic scores are not bounded, given that the Wasserstein distance is used. Since eg_loss is composed of multiple terms, the unbounded critic scores can dominate it and pull it negative. If you monitor eg_mse alone, it will give you a better intuition of reconstruction performance. As for cz_loss, it will oscillate, but not exactly around zero, since there is no bound.

Here is an example on the S-1 signal for 5 epochs, to make it easier to interpret what I mentioned above. The easiest way for me to monitor the loss is to look at the eg_loss entries, since they summarize both the critic and the reconstruction losses.

Epoch: 1/5, Losses: {
    'cx_loss': array([-3.4979, -6.7901,  0.5796,  0.2713]), 
    'cz_loss': array([ 4.7112, -0.4569,  4.7448,  0.0423]), 
    'eg_loss': array([ 2.0347, -0.5623, -4.3337,  0.6931])}
Epoch: 2/5, Losses: {
    'cx_loss': array([ -8.109 , -11.1804,   1.8813,   0.119 ]), 
    'cz_loss': array([10.9778, -0.3971,  9.7012,  0.1674]), 
    'eg_loss': array([-4.0139, -1.8051, -6.945 ,  0.4736])}
Epoch: 3/5, Losses: {
    'cx_loss': array([ -6.7245, -12.7957,   4.9852,   0.1086]), 
    'cz_loss': array([-19.2429,  -0.2567, -19.7197,   0.0734]), 
    'eg_loss': array([26.7805, -4.8654, 26.9848,  0.4661])}
Epoch: 4/5, Losses: {
    'cx_loss': array([ -7.7752, -13.9027,   4.81  ,   0.1318]), 
    'cz_loss': array([-14.144 ,  -0.3429, -14.2931,   0.0492]), 
    'eg_loss': array([15.6991, -4.7416, 15.6796,  0.4761])}
Epoch: 5/5, Losses: {
    'cx_loss': array([ -9.7054, -15.2037,   4.0292,   0.1469]), 
    'cz_loss': array([ 0.2413, -0.5112, -1.0699,  0.1822]), 
    'eg_loss': array([ 2.1054, -4.3083,  1.7519,  0.4662])}
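To make the arithmetic behind these printouts concrete: each total is the weighted sum of its three components, with the last component weighted by 10 (the gradient penalty coefficient for the critics, and the reconstruction weight for the encoder-generator). A quick check against the epoch-1 values above, assuming that weighting (a sketch, not the library's internal code):

```python
import numpy as np

# Detailed epoch-1 losses copied from the example: [total, term1, term2, term3].
cx = np.array([-3.4979, -6.7901, 0.5796, 0.2713])  # cx_loss, cx_real, cx_fake, cx_gp
eg = np.array([2.0347, -0.5623, -4.3337, 0.6931])  # eg_loss, eg_cx_fake, eg_cz_fake, eg_mse

# Recombine the components with weights 1, 1, 10.
cx_total = cx[1] + cx[2] + 10 * cx[3]
eg_total = eg[1] + eg[2] + 10 * eg[3]

print(cx_total, eg_total)  # matches cx[0] and eg[0] up to print rounding
```

The same check holds for every epoch in the log, which is why eg_loss swings so much whenever the critic terms grow while eg_mse stays small.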

Let me know if you have any further questions.

nunobv commented 1 year ago

Thank you very much for your incredibly detailed answer @sarahmish. Crystal clear, as always.

This leads me to one last thing that "intrigues" me: I understand that the eg_mse (reconstruction) loss is given a weight of 10, compared to 1 for both eg_cx_fake and eg_cz_fake. I infer this is an attempt to balance each loss contribution, which in turn has to do with the unbounded critic scores, as you've mentioned.

However, besides not being bounded (which, even with the unequal weights, could lead to an eg_loss leaning heavily towards the critic scores), the critic scores are not in the same "dimension" (for lack of a better term) as the reconstruction error: the former resembles a classification error, while the latter is a distance measure. So we are mixing different "realms" in the same loss function. Moreover, even with a weight of 10 on the reconstruction loss, the better the model gets at reconstructing the signal, the more heavily the loss function is dominated by the critic scores.

I understand this particular setup has yielded the best overall results on the datasets experimented with. I guess my question is: isn't it too situational and empirical? Do you think this is one parameter/setting that could highly influence the performance of TadGAN in particular applications/signals, or won't it have much influence?

Sorry for the loooong introduction! Cheers!

sarahmish commented 1 year ago

Definitely @nunobv! This is based on experimentation mostly.

The critic score and the MSE are not on the same scale, as you mentioned. The MSE gets us a reconstructed signal that looks close to the original one, while you can think of the critic score as a regularization term: it helps generate a signal without overfitting to anomalies.

You can experiment with the weights and see how they impact the reconstruction. From what I recall, if you decrease the weight of 10, the model doesn't reconstruct well.
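The trade-off can be illustrated with a standalone toy version of the combined loss (a sketch: the function name, the w_rec parameter, and the sample values are hypothetical, not Orion's internals):

```python
def encoder_generator_loss(cx_fake, cz_fake, mse, w_rec=10.0):
    """Toy combined loss: critic terms weighted 1, reconstruction weighted w_rec."""
    return cx_fake + cz_fake + w_rec * mse

# With the default weight of 10, a modest MSE still outweighs the critic terms.
balanced = encoder_generator_loss(-0.5, -4.3, 0.7)            # 10 * 0.7 dominates
# Lowering the weight lets the unbounded critic scores take over the total,
# which matches the observation that reconstruction quality then degrades.
critic_dominated = encoder_generator_loss(-0.5, -4.3, 0.7, w_rec=1.0)

print(balanced, critic_dominated)
```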