unit8co / darts

A Python library for user-friendly forecasting and anomaly detection on time series.
https://unit8co.github.io/darts/
Apache License 2.0

Refactor Transformer model #601

Open · hrzn opened this issue 2 years ago

hrzn commented 2 years ago

Currently the Transformer model is not implemented the way the architecture is meant to be used. We should revisit it and implement it as in the original Transformer paper: always train the model to predict the next sample (as language models do), and call the encoder and decoder auto-regressively when producing forecasts. See: Attention Is All You Need
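Training "to predict the next sample" usually relies on a causal (look-ahead) mask: the whole target sequence is fed to the decoder at once, and decoder position i is blocked from attending to any later position. The framework-free sketch below illustrates that mechanism. `causal_mask` and `masked_softmax_row` are hypothetical helpers, not part of darts. The mask follows the additive float convention accepted by `torch.nn.Transformer` (0.0 = allowed, -inf = blocked), which is also the shape produced by `torch.nn.Transformer.generate_square_subsequent_mask`.

```python
import math
from typing import List


def causal_mask(size: int) -> List[List[float]]:
    """Additive attention mask: query position i may attend to key positions 0..i only."""
    return [[0.0 if j <= i else -math.inf for j in range(size)] for i in range(size)]


def masked_softmax_row(scores: List[float], mask_row: List[float]) -> List[float]:
    """Softmax over one row of attention scores after adding the mask row."""
    masked = [s + m for s, m in zip(scores, mask_row)]
    peak = max(masked)  # finite for every causal-mask row, since position 0 is never blocked
    exps = [math.exp(x - peak) for x in masked]  # math.exp(-inf) == 0.0
    total = sum(exps)
    return [e / total for e in exps]


if __name__ == "__main__":
    mask = causal_mask(3)
    # Position 0 only sees itself; position 1 sees positions 0..1.
    print(masked_softmax_row([1.0, 2.0, 3.0], mask[0]))  # [1.0, 0.0, 0.0]
    print(masked_softmax_row([0.0, 0.0, 5.0], mask[1]))  # [0.5, 0.5, 0.0]
```

Future positions receive exactly zero attention weight, so every decoder position can be trained in parallel without seeing the value it is supposed to predict.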

Note from @pennfranc: The current implementation is fully functional and can already produce good predictions. However, it does not use the Transformer architecture to its full extent, because the tgt input of torch.nn.Transformer is underused: currently, we simply pass the last value of the src input as tgt. To get closer to how the Transformer is used in language models, the model should consume its own output as part of the tgt argument, so that when predicting a sequence of values, the tgt input grows by one element each time the model produces a new output. The training procedure would, of course, have to be adapted accordingly.
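A minimal sketch of the inference loop described in the note, assuming a hypothetical `model(src, tgt)` callable that returns one prediction per `tgt` position (output k being the forecast for the step after `tgt[k]`). `autoregressive_forecast` is an illustrative name, not darts' API. As in the current implementation, the decoder is seeded with the last `src` value; unlike it, each new output is appended to `tgt` before the next call.

```python
from typing import Callable, List, Sequence

# Hypothetical model interface: given the encoder input `src` and the decoder
# input `tgt`, return one predicted value per position of `tgt`.
Model = Callable[[Sequence[float], Sequence[float]], List[float]]


def autoregressive_forecast(model: Model, src: Sequence[float], horizon: int) -> List[float]:
    """Predict `horizon` steps, feeding each prediction back into `tgt`."""
    tgt = [src[-1]]  # seed the decoder with the last observed value
    forecast: List[float] = []
    for _ in range(horizon):
        outputs = model(src, tgt)
        next_value = outputs[-1]  # only the last position yields a new step
        forecast.append(next_value)
        tgt.append(next_value)  # tgt grows by one element per step
    return forecast


if __name__ == "__main__":
    # Toy stand-in for a trained model: predicts "previous value + 1" at every position.
    toy_model = lambda src, tgt: [v + 1.0 for v in tgt]
    print(autoregressive_forecast(toy_model, [1.0, 2.0, 3.0], horizon=3))  # [4.0, 5.0, 6.0]
```

This version re-runs the decoder over the whole growing `tgt` at every step, which is simple but quadratic in the horizon; caching earlier decoder states is a possible optimization not shown here.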

hrzn commented 1 year ago

See also: https://github.com/unit8co/darts/issues/672

JanFidor commented 1 year ago

Hi @dennisbader @madtoinou, while working on the RWKV PR I realized that I'm not using teacher forcing during training, which would hinder training quite a bit. Since teacher forcing is a big part of this issue, I wanted to ask if I could pick it up, so that I'd have a point of reference for how the final implementation should look if it gets merged (+ the issue looks really cool ;) )
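Teacher forcing means that during training the decoder receives the ground-truth target, shifted right by one step, instead of the model's own predictions; together with a causal mask, all output positions can then be trained in a single forward pass. Below is a sketch of how one training window could be split. `teacher_forcing_pair` is a hypothetical helper, and seeding the shifted target with the last `src` value is an assumption chosen to mirror the inference loop, not darts' current code.

```python
from typing import List, Sequence, Tuple


def teacher_forcing_pair(
    series: Sequence[float], input_len: int, output_len: int
) -> Tuple[List[float], List[float], List[float]]:
    """Split one training window into (src, tgt_in, tgt_out).

    tgt_in is tgt_out shifted right by one step and seeded with the last src
    value, so that under a causal mask decoder position k only sees true
    values up to the step before tgt_out[k], and is trained to predict tgt_out[k].
    """
    if input_len < 1 or output_len < 1:
        raise ValueError("input_len and output_len must be positive")
    if len(series) < input_len + output_len:
        raise ValueError("series is shorter than input_len + output_len")
    src = list(series[:input_len])
    tgt_out = list(series[input_len:input_len + output_len])
    tgt_in = [src[-1]] + tgt_out[:-1]  # ground truth, not model predictions
    return src, tgt_in, tgt_out


if __name__ == "__main__":
    src, tgt_in, tgt_out = teacher_forcing_pair([10.0, 11.0, 12.0, 13.0, 14.0, 15.0], 3, 3)
    print(src)      # [10.0, 11.0, 12.0]
    print(tgt_in)   # [12.0, 13.0, 14.0]
    print(tgt_out)  # [13.0, 14.0, 15.0]
```

At inference time the same decoder positions are filled with the model's own predictions instead, so teacher forcing trades a train/inference mismatch for faster and more stable training.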

dennisbader commented 1 year ago

Hi @JanFidor, of course :) We'd be happy to have your contribution to improve the Transformer model 🚀

Let us know if you need any assistance.