LongxingTan / Time-series-prediction

tfts: Time Series Deep Learning Models in TensorFlow
https://time-series-prediction.readthedocs.io/en/latest/
MIT License

AttributeError: 'Encoder' object has no attribute 'rnn_type' #35

Open forestbat opened 1 year ago

forestbat commented 1 year ago

❔Question

I initialize the model like this:

from tfts import AutoModel

model = AutoModel('seq2seq', custom_model_params={
    "rnn_type": "lstm",
    "bi_direction": False,
    "rnn_size": 64,
    "dense_size": 64,
    "num_stacked_layers": 1,
    "scheduler_sampling": 0,  # teacher forcing
    "use_attention": False,
    "attention_sizes": 64,
    "attention_heads": 2,
    "attention_dropout": 0,
    "skip_connect_circle": False,
    "skip_connect_mean": False,
})

But running it raises an error:

test_seq2seq.py:32: 
_ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _
C:\ProgramData\Anaconda3\envs\my2ndconda\lib\site-packages\tfts\models\auto_model.py:35: in __init__
    self.model = Seq2seq(predict_sequence_length=predict_length, custom_model_params=custom_model_params)
C:\ProgramData\Anaconda3\envs\my2ndconda\lib\site-packages\tfts\models\seq2seq.py:44: in __init__
    self.encoder = Encoder(
_ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _

self = <tfts.models.seq2seq.Encoder object at 0x0000026FBCA68CA0>
rnn_type = 'lstm', rnn_size = 64, rnn_dropout = 0, dense_size = 64, kwargs = {}

    def __init__(self, rnn_type, rnn_size, rnn_dropout=0, dense_size=32, **kwargs):
        super(Encoder, self).__init__(**kwargs)
        if rnn_type.lower() == "gru":
            self.rnn = GRU(
                units=rnn_size, activation="tanh", return_state=True, return_sequences=True, dropout=rnn_dropout
            )
>       elif self.rnn_type.lower() == "lstm":
E       AttributeError: 'Encoder' object has no attribute 'rnn_type'

C:\ProgramData\Anaconda3\envs\my2ndconda\lib\site-packages\tfts\models\seq2seq.py:112: AttributeError

What is going on here? My tfts version is 0.0.6.
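The traceback itself points at the cause: `Encoder.__init__` compares the argument `rnn_type` on the GRU branch, but on the LSTM branch it reads `self.rnn_type`, an attribute that was never assigned. The GRU path therefore succeeds and only non-GRU values such as `"lstm"` fail. The sketch below reproduces this without TensorFlow; `BuggyEncoder` and `FixedEncoder` are hypothetical stand-ins (plain classes, strings instead of real `GRU`/`LSTM` layers), not the actual tfts code, and whether 0.0.7 fixes it exactly this way has not been verified here. Going by the traceback alone, `"rnn_type": "gru"` would presumably avoid this particular line on 0.0.6, though other 0.0.6 code paths were not checked.

```python
# Hypothetical, TensorFlow-free reproduction of the bug shown in the traceback.
# The real tfts Encoder is a Keras layer that builds GRU/LSTM layers; strings stand in here.


class BuggyEncoder:
    def __init__(self, rnn_type, rnn_size, rnn_dropout=0, dense_size=32):
        if rnn_type.lower() == "gru":
            self.rnn = "GRU"
        # Bug: reads self.rnn_type, which was never assigned -> AttributeError.
        elif self.rnn_type.lower() == "lstm":
            self.rnn = "LSTM"


class FixedEncoder:
    def __init__(self, rnn_type, rnn_size, rnn_dropout=0, dense_size=32):
        self.rnn_type = rnn_type  # store it in case other methods need it later
        if rnn_type.lower() == "gru":
            self.rnn = "GRU"
        elif rnn_type.lower() == "lstm":  # compare the argument, not the attribute
            self.rnn = "LSTM"
        else:
            raise ValueError(f"unsupported rnn_type: {rnn_type!r}")


BuggyEncoder("gru", 64)  # works: the GRU branch never touches self.rnn_type
FixedEncoder("lstm", 64)  # works after the fix
```

Calling `BuggyEncoder("lstm", 64)` raises `AttributeError: 'BuggyEncoder' object has no attribute 'rnn_type'`, mirroring the error in the issue.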


LongxingTan commented 1 year ago

@forestbat

You could try upgrading to 0.0.7.
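Upgrading would typically be something like `pip install --upgrade "tfts>=0.0.7"`, assuming 0.0.7 is published on PyPI under the `tfts` name. To confirm which version the environment actually picks up (Anaconda setups can hide a stale install), a small check like the following sketch may help. It uses only the standard library's `importlib.metadata`; the helper names are hypothetical, and the parser handles only plain `X.Y.Z` strings.

```python
# Sketch: check that the installed tfts is at least 0.0.7, the version suggested in the reply.
from importlib.metadata import PackageNotFoundError, version

MIN_TFTS = (0, 0, 7)


def parse_version(text):
    """Turn a plain 'X.Y.Z' string into a tuple of ints (pre-release tags not supported)."""
    return tuple(int(part) for part in text.split(".")[:3])


def tfts_is_recent_enough(installed=None):
    """Return True if tfts >= 0.0.7; looks up the installed version when none is given."""
    if installed is None:
        try:
            installed = version("tfts")
        except PackageNotFoundError:
            return False  # tfts is not installed in this environment
    return parse_version(installed) >= MIN_TFTS


print(tfts_is_recent_enough())  # False on 0.0.6 or when tfts is missing
```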