gzerveas / mvts_transformer

Multivariate Time Series Transformer, public version
MIT License

obtain embeddings from trained models #14

Open gonlairo opened 2 years ago

gonlairo commented 2 years ago

Hello! Thank you for sharing the code of the paper!

I want to know if there is an easy way of extracting the embeddings (z_t) of a trained model. I was able to pre-train the model (unsupervised learning through input masking) but after I obtain the .pth files, I am struggling to obtain the embedding of the dataset.

gzerveas commented 2 years ago

Hi, I presume you are interested in obtaining static/precomputed embeddings for the entire dataset and storing them somewhere for future use. This feature has not been implemented. Instead, the approach currently available in the code is to load the model with frozen parameters and allow only the output layer to be trainable. This is equivalent to the former (if your objective is, e.g., to train a linear classifier on top of the embeddings), but it computes the embeddings on the fly.

However, it is simple to write the code for extracting and storing embeddings: load the model, loop over an initialized dataloader for your dataset of interest, and add the last hidden layer activations to the model's outputs. You would then use e.g. torch.save or numpy.save to write the accumulated embedding tensors/numpy arrays to disk. You could also write a simple function that, for each sample, aggregates z_t over all t (e.g. by averaging or concatenating).
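That loop could be sketched roughly as follows. This is a minimal standalone sketch, not code from the repository: `model` and `loader` are placeholders (a plain linear layer and random batches) standing in for the trained transformer encoder and a real DataLoader.

```python
import torch
from torch import nn

# Placeholders: in practice, `model` is the pre-trained encoder loaded from the
# .pth checkpoint, and `loader` is a DataLoader over the dataset of interest.
d_model = 16
model = nn.Linear(8, d_model)                       # stand-in for the encoder
loader = [torch.randn(4, 10, 8) for _ in range(3)]  # (batch, seq_len, feat_dim) batches

model.eval()
per_sample = []
with torch.no_grad():                     # gradients are not needed for extraction
    for X in loader:
        z = model(X)                      # z_t for every step: (batch, seq_len, d_model)
        per_sample.append(z.mean(dim=1))  # aggregate over t, e.g. by averaging

embeddings = torch.cat(per_sample, dim=0)  # (num_samples, d_model)
torch.save(embeddings, "embeddings.pt")    # or numpy.save with embeddings.numpy()
```

Concatenating z_t over t instead of averaging would keep per-step information, at the cost of an embedding that is seq_len times larger.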

To reuse all the existing code machinery and avoid an extra iteration over the dataset, I would personally do it by modifying the evaluate function of class UnsupervisedRunner to do exactly that: accumulate the model's outputs (this is done already), but now also accumulate embeddings (see below how) when given a new, additional argument extract_embeddings. You could use this argument e.g. at the last epoch's model evaluation/validation. For example, in line 263 of main.py you could do something like: aggr_metrics_val, best_metrics, best_value, embeddings = validate(val_evaluator, tensorboard_writer, config, best_metrics, best_value, epoch, extract_embeddings=(epoch == config["epochs"])), and finally write embeddings to disk.

If instead you want to do this extraction as a separate operation (i.e. not at the end of pre-training, but loading an already available model checkpoint), you would get the embeddings of whatever dataset you define as the "test set" (through the flag --test_pattern) simply by running the main.py script with the flag --test_only. If you don't want to get evaluation metrics (or don't have labels for your dataset), I would replicate UnsupervisedRunner.evaluate into a new method UnsupervisedRunner.extract_embeddings, which would be identical but would skip computing the loss, metrics etc. and only do the feature extraction (see below how).

Just remember that we don't need to compute gradients when computing the embeddings, so all of the above should be wrapped in a with torch.no_grad(): context, to save memory and time. For example, assuming that you have written a UnsupervisedRunner.extract_embeddings(self) member function which is a replica of UnsupervisedRunner.evaluate(self, epoch_num=None, keep_all=True) but without the loss computation and metrics calculation, you would need to add something like this to main.py:

# no loss module is needed, since we only extract features
embeddings_extractor = UnsupervisedRunner(model, loader, device, loss_module=None)
with torch.no_grad():
    embeddings = embeddings_extractor.extract_embeddings()
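A reduced sketch of what such an UnsupervisedRunner.extract_embeddings method could look like. This is a hypothetical simplification, not the real class: the actual runner in running.py also handles masking, padding masks, loss and metrics, all omitted here, and the toy model at the bottom only illustrates the expected interface.

```python
import torch

class UnsupervisedRunner:
    # Hypothetical, heavily reduced sketch of the runner, for illustration only.
    def __init__(self, model, dataloader, device, loss_module=None):
        self.model = model
        self.dataloader = dataloader
        self.device = device

    def extract_embeddings(self):
        # Replica of evaluate() minus loss/metrics: pure feature extraction.
        self.model.eval()
        chunks = []
        with torch.no_grad():
            for X in self.dataloader:
                X = X.to(self.device)
                # assumes the model exposes a get_embeddings method,
                # as suggested elsewhere in this thread
                chunks.append(self.model.get_embeddings(X).cpu())
        return torch.cat(chunks, dim=0)

# Toy stand-ins for the real model and DataLoader, just to run the sketch:
class _ToyModel(torch.nn.Module):
    def __init__(self):
        super().__init__()
        self.proj = torch.nn.Linear(8, 16)

    def get_embeddings(self, X):
        return self.proj(X).mean(dim=1)   # one embedding per sample

runner = UnsupervisedRunner(_ToyModel(), [torch.randn(4, 10, 8)], device="cpu")
embeddings = runner.extract_embeddings()  # shape (4, 16)
```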

So how can you get the final embeddings of the transformer encoder? For the TSTransformerEncoder model that has undergone unsupervised pre-training, you can get them right before or after line 242 (i.e., before applying dropout); for the model fine-tuned for classification or regression, after line 304. You can return this tensor as a second output of the model's forward function, e.g. called embeddings, or, better still, define another member function get_embeddings, which would be identical to forward minus the last few lines and would return only the embeddings. This would be called by the evaluate or extract_embeddings function mentioned above, which would apply get_embeddings to your desired dataset.
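To make the forward/get_embeddings split concrete, here is a hypothetical miniature encoder — not the actual TSTransformerEncoder, which additionally applies positional encoding, padding masks and an activation. get_embeddings runs the pipeline up to, but excluding, the dropout and output layer; forward calls it and adds those final layers.

```python
import torch
from torch import nn

class TinyEncoder(nn.Module):
    # Hypothetical miniature of TSTransformerEncoder, for illustration only.
    def __init__(self, feat_dim=8, d_model=16):
        super().__init__()
        self.project_inp = nn.Linear(feat_dim, d_model)
        layer = nn.TransformerEncoderLayer(d_model, nhead=2, batch_first=True)
        self.transformer_encoder = nn.TransformerEncoder(layer, num_layers=1)
        self.dropout = nn.Dropout(0.1)
        self.output_layer = nn.Linear(d_model, feat_dim)

    def get_embeddings(self, X):
        # identical to forward() up to, but excluding, dropout + output layer
        inp = self.project_inp(X)
        return self.transformer_encoder(inp)   # z_t: (batch, seq_len, d_model)

    def forward(self, X):
        # full pipeline: embeddings followed by the final layers
        return self.output_layer(self.dropout(self.get_embeddings(X)))

model = TinyEncoder().eval()
with torch.no_grad():
    z = model.get_embeddings(torch.randn(2, 10, 8))  # shape (2, 10, 16)
```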

I am currently engaged in so many other projects, however, that I can't really tell when I would have time to implement this. It would be great if you could implement it and submit a pull request.