ML4ITS / mtad-gat-pytorch

PyTorch implementation of MTAD-GAT (Multivariate Time-Series Anomaly Detection via Graph Attention Networks) by Zhao et al. (2020, https://arxiv.org/abs/2009.02040).
MIT License

A Question about the implementation. #2

Closed Kevin-XiongC closed 3 years ago

Kevin-XiongC commented 3 years ago

Thanks for making this repo public. I have some questions after reading your code.

https://github.com/ML4ITS/mtad-gat-pytorch/blob/8f907ccef2695252a20db6d32536cc07f3fc53e8/training.py#L123 The paper uses a VAE-like method for the reconstruction model, but you simply use MSE like a naive autoencoder. I wonder whether this is because of optimization stability. In my case, I tried to sample at every timestep, as in an LSTM-VAE model, and sometimes the loss just became nan.

axeloh commented 3 years ago

Hi,

In our case, we wanted to start with a simpler reconstruction model, so we chose a GRU-based "decoder" (a decoder in the sense that the last hidden state of the GRU layer is extracted and used as the encoding that the recon model decodes back into the original input), and it turned out to work quite well. We might implement a VAE as the recon model at a later time.

To your problem of nan loss, have you tried clipping gradients and/or changing the learning rate?
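As a minimal sketch of the two suggestions above (the model and optimizer here are illustrative stand-ins, not the repo's actual training code):

```python
import torch

# Hypothetical tiny model and optimizer, for illustration only.
model = torch.nn.Linear(4, 1)
optimizer = torch.optim.Adam(model.parameters(), lr=1e-4)  # smaller lr than Adam's default 1e-3

x = torch.randn(8, 4)
loss = model(x).pow(2).mean()
loss.backward()

# Clip each gradient element into [-1, 1] before the optimizer step,
# which caps the size of any single update and often tames exploding/nan losses.
torch.nn.utils.clip_grad_value_(model.parameters(), clip_value=1.0)
optimizer.step()
```

`torch.nn.utils.clip_grad_norm_` (clipping the global norm instead of each element) is a common alternative.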

Kevin-XiongC commented 3 years ago

Hi,

Thanks for the reply. I've tried gradient clipping (between -1 and 1) and small learning rates in [1e-5, 1e-4], but it still didn't work out. Maybe it's because I am using my own dataset. The difficulty for me is that the paper does not explain what its recon model looks like, and there are many ways to reconstruct the original sequence from the encoder's outputs. As for the VAE part, another re-implementation treats the whole sequence as an N*T random variable and parameterizes the encoder's posterior P(Z|X) and the decoder's distribution P(X|Z) respectively. In my case, I only output a distribution and sample at the encoder, and the numerical issue no longer appears. Hope this is helpful for your later work.
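The "sample only at the encoder" setup described above could be sketched like this (class and variable names are hypothetical, not from the repo or the re-implementation mentioned): a stochastic encoder using the reparameterization trick, followed by a deterministic decoder.

```python
import torch
import torch.nn as nn

class EncoderOnlyVAE(nn.Module):
    """Illustrative sketch: stochastic encoder, deterministic decoder.

    Only the encoder is probabilistic; z ~ N(mu, sigma) is sampled once
    via the reparameterization trick, then decoded deterministically.
    """
    def __init__(self, in_dim, latent_dim):
        super().__init__()
        self.enc = nn.Linear(in_dim, 2 * latent_dim)  # outputs [mu, log_var]
        self.dec = nn.Linear(latent_dim, in_dim)

    def forward(self, x):
        mu, log_var = self.enc(x).chunk(2, dim=-1)
        std = torch.exp(0.5 * log_var)            # always > 0
        z = mu + std * torch.randn_like(std)      # reparameterization trick
        return self.dec(z), mu, log_var

model = EncoderOnlyVAE(in_dim=20, latent_dim=8)
recon, mu, log_var = model(torch.randn(16, 20))
```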

Anyway, another interesting thing is the sliding window. Most time-series anomaly-detection models I've seen use a sliding window with stride == 1, which may increase the chance of overfitting because neighboring samples overlap heavily. I know this is more general and goes beyond the paper, but I am working on an industrial system and my lead warned me against stride == 1 sliding windows because they can leak the label to the model.
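To make the stride trade-off concrete, here is a small sketch (function name and shapes are illustrative): stride == 1 produces maximally overlapping windows, while a larger stride reduces both the overlap and the number of training samples.

```python
import numpy as np

def sliding_windows(series, window, stride):
    """Cut a (T, d) series into (num_windows, window, d) samples."""
    starts = range(0, len(series) - window + 1, stride)
    return np.stack([series[s:s + window] for s in starts])

series = np.random.randn(100, 3)
# stride=1: adjacent windows share 9 of 10 timestamps -> 91 samples
dense = sliding_windows(series, window=10, stride=1)
# stride=10: non-overlapping windows -> 10 samples
sparse = sliding_windows(series, window=10, stride=10)
```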

axeloh commented 3 years ago

I see. I do not know too much about VAEs at this point, but I recall that because the model outputs mu and sigma and samples from N(mu, sigma), sigma must be positive, which one can ensure by treating the model's raw output as log(sigma) and using sigma = e^[log(sigma)] > 0. If you are not doing this, could that be creating nan values?
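The log(sigma) parameterization mentioned above amounts to one exponentiation (the raw values here are hypothetical network outputs):

```python
import torch

# Treat the network's raw output as log(sigma); exponentiating guarantees
# sigma > 0 no matter what the network emits (even large negative values).
raw = torch.tensor([-5.0, 0.0, 3.0])   # hypothetical unconstrained outputs
sigma = torch.exp(raw)                 # strictly positive

# Sampling with a guaranteed-positive sigma:
mu = torch.zeros(3)
z = mu + sigma * torch.randn(3)
```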

Thanks for your input on the VAE. And I agree, the paper is quite unclear about both what the output of the GRU layer is and how the VAE is implemented... As mentioned earlier, we chose to extract the last hidden state (belonging to the last timestamp) as the output of the GRU layer. From this point of view, the path going from the input data through the reconstruction model constitutes an autoencoder: the input data is encoded into a latent vector of the GRU layer's hidden dimension and then decoded by the recon model. Thus, the original input is encoded from dimension n x d (n = number of input timestamps, d = input dimension) down to the hidden dimension of the GRU. Let's say n = 100, d = 20, and the GRU hidden dimension is 150; then the input is reduced from 100 * 20 = 2000 values to 150, before the recon model decodes it back to n x d.
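Shape-wise, that autoencoder view looks roughly like this (the linear decoder is a simplified stand-in for the repo's recon model, not its actual architecture):

```python
import torch
import torch.nn as nn

n, d, hid = 100, 20, 150  # window length, input dim, GRU hidden dim

gru = nn.GRU(input_size=d, hidden_size=hid, batch_first=True)
decoder = nn.Linear(hid, n * d)  # simplified stand-in for the recon model

x = torch.randn(1, n, d)                # one window: 100 * 20 = 2000 values
_, h_n = gru(x)                         # h_n: (1, 1, 150), the encoding
recon = decoder(h_n[-1]).view(1, n, d)  # decoded back to the window shape
```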

Also note that the same GRU output is fed to the forecasting model as well, forcing the GRU output to contain "info" useful both for forecasting the next value and for reconstructing the complete input window.

I have not heard this argument about overfitting from a low stride before. It is not clear to me how this could actually happen, as each sample (window) in the batch is processed in parallel and there is no way for the model to "see" across samples (?). However, if it were the case, shuffling the samples within each batch would prevent this, no?