awslabs / gluonts

Probabilistic time series modeling in Python
https://ts.gluon.ai

Improve Transformer model. #698

Open AaronSpieler opened 4 years ago

AaronSpieler commented 4 years ago

Description

TODO:

The corresponding paper: https://arxiv.org/abs/1706.03762

kashif commented 4 years ago

@AaronSpieler this paper might be helpful too: https://papers.nips.cc/paper/8766-enhancing-the-locality-and-breaking-the-memory-bottleneck-of-transformer-on-time-series-forecasting.pdf

Callidior commented 4 years ago

While we are waiting for multi-layer transformers to be implemented in GluonTS, you can use this workaround to get a variant of TransformerEstimator that supports arbitrary stack sizes:

from gluonts.model import transformer

class MultilayerTransformerEncoder(transformer.trans_encoder.TransformerEncoder):

    def __init__(self, encoder_length, config, num_layers, **kwargs):
        # The base class acts as the first encoder layer and gets the prefix '<prefix>1_'.
        prefix = kwargs.pop('prefix', '')

        super().__init__(encoder_length, config, prefix=prefix + '1_', **kwargs)

        # Layers 2..num_layers are plain TransformerEncoders; register them as
        # children so their parameters are initialized and trained.
        self.layers = [
            transformer.trans_encoder.TransformerEncoder(
                encoder_length, config, prefix=prefix + str(i) + '_', **kwargs
            )
            for i in range(2, num_layers + 1)
        ]
        for layer in self.layers:
            self.register_child(layer)

    def hybrid_forward(self, F, data, *args, **kwargs):
        # Pass the data through the base (first) layer, then through the
        # remaining layers in order.
        data = super().hybrid_forward(F, data, *args, **kwargs)
        for layer in self.layers:
            data = layer.hybrid_forward(F, data, *args, **kwargs)
        return data

class MultilayerTransformerDecoder(transformer.trans_decoder.TransformerDecoder):

    def __init__(self, decoder_length, config, num_layers, **kwargs):
        # The base class acts as the first decoder layer and gets the prefix '<prefix>1_'.
        prefix = kwargs.pop('prefix', '')

        super().__init__(decoder_length, config, prefix=prefix + '1_', **kwargs)

        # Layers 2..num_layers are plain TransformerDecoders; register them as
        # children so their parameters are initialized and trained.
        self.layers = [
            transformer.trans_decoder.TransformerDecoder(
                decoder_length, config, prefix=prefix + str(i) + '_', **kwargs
            )
            for i in range(2, num_layers + 1)
        ]
        for layer in self.layers:
            self.register_child(layer)

    def hybrid_forward(self, F, data, *args, **kwargs):
        # Pass the data through the base (first) layer, then through the
        # remaining layers in order.
        data = super().hybrid_forward(F, data, *args, **kwargs)
        for layer in self.layers:
            data = layer.hybrid_forward(F, data, *args, **kwargs)
        return data

    def cache_reset(self):
        # Reset the attention cache of every decoder layer, not just the first one.
        super().cache_reset()
        for layer in self.layers:
            layer.cache_reset()

class MultilayerTransformerEstimator(transformer.TransformerEstimator):

    def __init__(self, freq: str, prediction_length: int, num_layers: int = 1, **kwargs):
        super().__init__(freq, prediction_length, **kwargs)

        assert num_layers > 0, "The value of `num_layers` should be > 0"

        # Replace the single-layer encoder/decoder built by the base class
        # with multi-layer versions.
        self.encoder = MultilayerTransformerEncoder(
            self.context_length, self.config, num_layers, prefix='enc_'
        )

        # The decoder runs over the prediction horizon, so it is built with
        # the prediction length rather than the context length.
        self.decoder = MultilayerTransformerDecoder(
            self.prediction_length, self.config, num_layers, prefix='dec_'
        )

MultilayerTransformerEstimator takes the same arguments as TransformerEstimator, but provides an additional num_layers argument.
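For example, a minimal usage sketch (the frequency, prediction length, layer count, and the train_ds dataset below are illustrative placeholders, not values taken from this thread):

estimator = MultilayerTransformerEstimator(
    freq="H",                 # hourly data (placeholder)
    prediction_length=24,     # forecast horizon (placeholder)
    num_layers=3,             # stack three encoder/decoder layers
)
predictor = estimator.train(training_data=train_ds)  # train_ds: any GluonTS training dataset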