LongxingTan / Data-competitions

My Data Competition Solutions

How do I specify a validation set when there is only one table? #5

Closed forestbat closed 1 year ago

forestbat commented 1 year ago

First of all, thank you to the author for this work. I have a CSV dataset like this:

TM,Z
1,97.92
2,97.92
3,97.92
4,97.93
5,97.93
6,97.93
7,97.93
8,97.94
……

I now want to use seq2seq to predict future Z values. My steps are as follows:

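import numpy as np
import pandas as pd
from sklearn.model_selection import train_test_split
from tfts import AutoModel, KerasTrainer  # assumption: these classes come from the tfts package

test_df = pd.read_csv('my_table.csv')
# Split the row index (inputs) and the Z column (targets) 75/25
X_train, X_test, Y_train, Y_test = train_test_split(test_df.index, test_df['Z'], test_size=0.25)
train_length = len(X_train)
predict_length = len(X_test)
model = AutoModel('seq2seq', predict_length)
trainer = KerasTrainer(model)
# Add a leading batch axis and a trailing feature axis: (n,) -> (1, n, 1)
X_train_dim = np.expand_dims(X_train.values, axis=(0, -1))
X_test_dim = np.expand_dims(X_test.values, axis=(0, -1))
Y_train_dim = np.expand_dims(Y_train.values, axis=(0, -1))
Y_test_dim = np.expand_dims(Y_test.values, axis=(0, -1))
trainer.train(train_dataset=(X_train_dim, Y_train_dim), valid_dataset=?)  # what should go here?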

Setting valid_dataset to None raises an error (even though the code comments say valid_dataset can be None). Setting valid_dataset to (X_train_dim, Y_train_dim) or to (X_test_dim, Y_test_dim) also raises an error:
Incompatible shapes: [1,5081,1] vs. [1,15240,1] [Op:SquaredDifference]
(In this example, train_length=15240 and predict_length=5081.) What should I pass as valid_dataset?
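
The shapes in the error suggest that the model emits predict_length steps (5081) while the target passed in has train_length steps (15240), so the loss cannot compare them. Note also that train_test_split shuffles rows by default, which breaks the time order of a series. A minimal sketch of one common alternative, using NumPy only: slice the single series into fixed-length (input, target) windows and hold out the most recent windows for validation. The helper names make_windows and chronological_split, the window sizes, and the synthetic series are all hypothetical, and whether the trainer accepts these tuples directly is an assumption to check against the library's documentation.

```python
import numpy as np


def make_windows(series, train_length, predict_length):
    """Slice a 1-D series into sliding (input, target) window pairs."""
    series = np.asarray(series, dtype=np.float32)
    n_windows = len(series) - train_length - predict_length + 1
    if n_windows <= 0:
        raise ValueError("series is too short for the requested window sizes")
    # Each input window is followed immediately by its target window.
    x = np.stack([series[i:i + train_length] for i in range(n_windows)])
    y = np.stack([
        series[i + train_length:i + train_length + predict_length]
        for i in range(n_windows)
    ])
    # Add a trailing feature axis: (n_windows, steps) -> (n_windows, steps, 1)
    return x[..., np.newaxis], y[..., np.newaxis]


def chronological_split(x, y, valid_fraction=0.25):
    """Keep the last windows for validation instead of shuffling."""
    n_valid = max(1, int(len(x) * valid_fraction))
    return (x[:-n_valid], y[:-n_valid]), (x[-n_valid:], y[-n_valid:])


# Synthetic stand-in for the Z column (hypothetical data, not the real table).
z = np.linspace(97.92, 98.50, 200)
x, y = make_windows(z, train_length=24, predict_length=8)
train_dataset, valid_dataset = chronological_split(x, y)

print(train_dataset[0].shape, train_dataset[1].shape)
print(valid_dataset[0].shape, valid_dataset[1].shape)
```

With this layout every target has exactly predict_length steps in both splits, so the training and validation targets match the model's output length. Windows on either side of the split boundary still share some time steps; leaving a gap of train_length + predict_length windows between the two sets would avoid that overlap.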

forestbat commented 1 year ago

Sorry, I posted this in the wrong place.