farizrahman4u / seq2seq

Sequence to Sequence Learning with Keras
GNU General Public License v2.0

output_length = 1 not working #242

Open sahil231090 opened 6 years ago

sahil231090 commented 6 years ago

I want to predict a single step of a 1D time series with MSE loss, but training fails when output_length = 1:

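```python
import numpy as np
from seq2seq.models import SimpleSeq2Seq

# Predict a single future step (output_length = 1): this fails
input_length = 5
input_dim = 2

output_length = 1
output_dim = 1

samples = 100
hidden_dim = 24

# x: (samples, input_length, input_dim), y: (samples, output_length, output_dim)
x = np.random.random((samples, input_length, input_dim))
y = np.random.random((samples, output_length, output_dim))

model = SimpleSeq2Seq(output_dim=output_dim, hidden_dim=hidden_dim,
                      output_length=output_length,
                      input_shape=(input_length, input_dim), unroll=True)
model.compile(loss='mse', optimizer='sgd')
model.fit(x, y)  # raises the ValueError shown in the error stack
```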
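For reference, here is a minimal NumPy sketch of the shape contract the `'mse'` loss relies on, using the sizes from the failing snippet: the prediction must have the same shape as the target, `(samples, output_length, output_dim)`, and the loss (mirrored from the `losses.py` line in the traceback) averages over the last axis. `mse_like_keras` is a hypothetical helper for illustration, not a Keras or seq2seq function, and a zero array stands in for real model output.

```python
import numpy as np

# Sizes from the failing snippet (output_length = 1)
samples, input_length, input_dim = 100, 5, 2
output_length, output_dim = 1, 1

x = np.random.random((samples, input_length, input_dim))
y = np.random.random((samples, output_length, output_dim))

def mse_like_keras(y_true, y_pred):
    # Hypothetical NumPy mirror of the loss in the traceback:
    # K.mean(K.square(y_pred - y_true), axis=-1)
    return np.mean(np.square(y_pred - y_true), axis=-1)

# Stand-in prediction with the same shape as the target
y_pred = np.zeros_like(y)

print(x.shape, y.shape, mse_like_keras(y, y_pred).shape)
# (100, 5, 2) (100, 1, 1) (100, 1)
```

With matching shapes the subtraction is purely elementwise and the per-timestep loss has shape `(samples, output_length)`.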

But if I predict the time series 2 steps ahead (output_length = 2), it works:

```python
import numpy as np
from seq2seq.models import SimpleSeq2Seq

# Predict two future steps (output_length = 2): this trains without error
input_length = 5
input_dim = 2

output_length = 2
output_dim = 1

samples = 100
hidden_dim = 24

x = np.random.random((samples, input_length, input_dim))
y = np.random.random((samples, output_length, output_dim))

model = SimpleSeq2Seq(output_dim=output_dim, hidden_dim=hidden_dim,
                      output_length=output_length,
                      input_shape=(input_length, input_dim), unroll=True)
model.compile(loss='mse', optimizer='sgd')
model.fit(x, y)
```

Versions: Python 2.7, Keras 2.0.6, Theano 1.0.1

Error stack:

```
  File "", line 15, in <module>
    model.fit(x,y)
  File "c:\pithon.shared\prod\400.106.1\lib\site-packages\keras\engine\training.py", line 1430, in fit
    initial_epoch=initial_epoch)
  File "c:\pithon.shared\prod\400.106.1\lib\site-packages\keras\engine\training.py", line 1079, in _fit_loop
    outs = f(ins_batch)
  File "c:\pithon.shared\prod\400.106.1\lib\site-packages\keras\backend\theano_backend.py", line 1197, in __call__
    return self.function(*inputs)
  File "c:\pithon.shared\prod\400.106.1\lib\site-packages\theano\compile\function_module.py", line 917, in __call__
    storage_map=getattr(self.fn, 'storage_map', None))
  File "c:\pithon.shared\prod\400.106.1\lib\site-packages\theano\gof\link.py", line 325, in raise_with_op
    reraise(exc_type, exc_value, exc_trace)
  File "c:\pithon.shared\prod\400.106.1\lib\site-packages\theano\compile\function_module.py", line 903, in __call__
    self.fn() if output_subset is None else\

ValueError: Input dimension mis-match. (input[0].shape[2] = 32, input[1].shape[2] = 1)
Apply node that caused the error: Elemwise{sub,no_inplace}(Elemwise{mul,no_inplace}.0, /recurrent_sequential_78_target)
Toposort index: 205
Inputs types: [TensorType(float32, (True, False, False)), TensorType(float32, 3D)]
Inputs shapes: [(1L, 1L, 32L), (32L, 1L, 1L)]
Inputs strides: [(128L, 128L, 4L), (4L, 4L, 4L)]
Inputs values: ['not shown', 'not shown']
Outputs clients: [[Elemwise{mul,no_inplace}(Elemwise{Composite{(((i0 * i1) / i2) / i3)}}.0, Elemwise{sub,no_inplace}.0), Elemwise{Sqr}(0, 0)]]

Backtrace when the node is created(use Theano flag traceback.limit=N to make it longer):
  File "c:\pithon.shared\prod\400.106.1\lib\site-packages\ipykernel\zmqshell.py", line 501, in run_cell
    return super(ZMQInteractiveShell, self).run_cell(*args, **kwargs)
  File "c:\pithon.shared\prod\400.106.1\lib\site-packages\IPython\core\interactiveshell.py", line 2717, in run_cell
    interactivity=interactivity, compiler=compiler, result=result)
  File "c:\pithon.shared\prod\400.106.1\lib\site-packages\IPython\core\interactiveshell.py", line 2821, in run_ast_nodes
    if self.run_code(code, result):
  File "c:\pithon.shared\prod\400.106.1\lib\site-packages\IPython\core\interactiveshell.py", line 2881, in run_code
    exec(code_obj, self.user_global_ns, self.user_ns)
  File "", line 14, in <module>
    model.compile(loss='mse', optimizer='sgd')
  File "c:\pithon.shared\prod\400.106.1\lib\site-packages\keras\engine\training.py", line 840, in compile
    sample_weight, mask)
  File "c:\pithon.shared\prod\400.106.1\lib\site-packages\keras\engine\training.py", line 446, in weighted
    score_array = fn(y_true, y_pred)
  File "c:\pithon.shared\prod\400.106.1\lib\site-packages\keras\losses.py", line 8, in mean_squared_error
    return K.mean(K.square(y_pred - y_true), axis=-1)

HINT: Use the Theano flag 'exception_verbosity=high' for a debugprint and storage map footprint of this apply node.
```
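The shapes in the error can be reproduced with NumPy to see where the mismatch comes from. This is a minimal sketch, assuming the 32 is Keras' default `batch_size` of 32 (the snippets call `fit` without `batch_size`), with random arrays standing in for the real tensors. The target reaching `y_pred - y_true` has shape `(32, 1, 1)`, but the prediction has shape `(1, 1, 32)`, which suggests the batch axis ended up last. Theano only broadcasts axes whose type marks them broadcastable (here only axis 0 of the prediction), so it raises on axis 2; NumPy broadcasts any size-1 axis and would not complain.

```python
import numpy as np

batch = 32  # assumption: Keras' default batch_size, matching the 32 in the error

# Shapes copied from "Inputs shapes" in the error; values are random stand-ins
y_pred_reported = np.random.random((1, 1, batch))  # input[0], the prediction
y_true = np.random.random((batch, 1, 1))           # input[1], the target

# NumPy broadcasts both size-1 axes, so the subtraction silently succeeds...
diff = y_pred_reported - y_true
print(diff.shape)  # (32, 1, 32)

# ...and the MSE even has the (32, 1) shape a correct per-sample loss would have
loss = np.mean(np.square(diff), axis=-1)
print(loss.shape)  # (32, 1)

# Swapping axes 0 and 2 gives the target's layout. This only illustrates the
# symptom; it is a hypothesis, not a fix for SimpleSeq2Seq.
print(np.swapaxes(y_pred_reported, 0, 2).shape)  # (32, 1, 1)
```

If the swapped layout is the intended one, the problem would lie in how the model arranges its output when `output_length = 1`, not in the training data; this sketch does not confirm that.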