Open AaronSpieler opened 4 years ago
@AaronSpieler this paper might be helpful too: https://papers.nips.cc/paper/8766-enhancing-the-locality-and-breaking-the-memory-bottleneck-of-transformer-on-time-series-forecasting.pdf
While we are waiting for multi-layer transformers to be implemented in GluonTS, you can use this work-around to get a variant of TransformerEstimator supporting arbitrary stack sizes:
from gluonts.model import transformer
class MultilayerTransformerEncoder(transformer.trans_encoder.TransformerEncoder):
    """A stack of GluonTS TransformerEncoder layers.

    The instance itself acts as layer 1 (via the base class); the remaining
    ``num_layers - 1`` encoders are created as child blocks and applied in
    sequence, each under its own parameter prefix so weights do not collide.
    """

    def __init__(self, encoder_length, config, num_layers, **kwargs):
        # Pull the caller's prefix out so every sub-layer gets a unique
        # parameter scope ('<prefix>1_', '<prefix>2_', ...).
        prefix = kwargs.pop('prefix', '')
        super().__init__(encoder_length, config, prefix=prefix + '1_', **kwargs)
        extra = []
        for idx in range(2, num_layers + 1):
            block = transformer.trans_encoder.TransformerEncoder(
                encoder_length, config, prefix=prefix + str(idx) + '_', **kwargs
            )
            # Plain-list attributes are not auto-registered by Gluon, so
            # register each child explicitly for parameter collection.
            self.register_child(block)
            extra.append(block)
        self.layers = extra

    def hybrid_forward(self, F, data, *args, **kwargs):
        # Layer 1 is the base class itself; the rest are chained on top.
        out = super().hybrid_forward(F, data, *args, **kwargs)
        for block in self.layers:
            out = block.hybrid_forward(F, out, *args, **kwargs)
        return out
class MultilayerTransformerDecoder(transformer.trans_decoder.TransformerDecoder):
    """A stack of GluonTS TransformerDecoder layers.

    The instance itself acts as layer 1 (via the base class); the remaining
    ``num_layers - 1`` decoders are created as child blocks and applied in
    sequence, each under its own parameter prefix so weights do not collide.
    """

    def __init__(self, decoder_length, config, num_layers, **kwargs):
        # Pull the caller's prefix out so every sub-layer gets a unique
        # parameter scope ('<prefix>1_', '<prefix>2_', ...).
        prefix = kwargs.pop('prefix', '')
        super().__init__(decoder_length, config, prefix=prefix + '1_', **kwargs)
        extra = []
        for idx in range(2, num_layers + 1):
            block = transformer.trans_decoder.TransformerDecoder(
                decoder_length, config, prefix=prefix + str(idx) + '_', **kwargs
            )
            # Plain-list attributes are not auto-registered by Gluon, so
            # register each child explicitly for parameter collection.
            self.register_child(block)
            extra.append(block)
        self.layers = extra

    def hybrid_forward(self, F, data, *args, **kwargs):
        # Layer 1 is the base class itself; the rest are chained on top.
        out = super().hybrid_forward(F, data, *args, **kwargs)
        for block in self.layers:
            out = block.hybrid_forward(F, out, *args, **kwargs)
        return out

    def cache_reset(self):
        # Clear the inference cache of every stacked decoder, not just layer 1.
        super().cache_reset()
        for block in self.layers:
            block.cache_reset()
class MultilayerTransformerEstimator(transformer.TransformerEstimator):
    """TransformerEstimator variant supporting an arbitrary number of
    encoder/decoder layers via the extra ``num_layers`` argument.

    With ``num_layers=1`` this is equivalent to the base estimator.
    """

    def __init__(self, freq: str, prediction_length: int, num_layers: int = 1, **kwargs):
        # Validate up front, before the (potentially expensive) base setup.
        # Use an explicit exception rather than `assert` so the check is not
        # stripped when Python runs with -O.
        if num_layers <= 0:
            raise ValueError("The value of `num_layers` should be > 0")
        super().__init__(freq, prediction_length, **kwargs)
        self.encoder = MultilayerTransformerEncoder(
            self.context_length, self.config, num_layers, prefix='enc_'
        )
        # Bug fix: the base TransformerEstimator builds its decoder with
        # prediction_length (the decoder operates on the forecast horizon),
        # not context_length — mirror that here.
        self.decoder = MultilayerTransformerDecoder(
            self.prediction_length, self.config, num_layers, prefix='dec_'
        )
`MultilayerTransformerEstimator` takes the same arguments as `TransformerEstimator`, but provides an additional `num_layers` argument.
Description
N=1
TODO:
The corresponding paper ("Attention Is All You Need"): https://arxiv.org/abs/1706.03762