nmheim / torsk

An echo state network (ESN) for video prediction

Convolutional ESN #25

nmheim closed this issue 5 years ago

nmheim commented 5 years ago

Implement a convolutional input layer

nmheim commented 5 years ago

@jamesavery the ConvESN seems to be working in theory, but I get a weird error when it starts predicting:

DEBUG:torsk:Creating 1200 training states
DEBUG:torsk:Optimizing output weights
torch.Size([1100, 900]) torch.Size([1100, 1225])
DEBUG:torsk:Predicting the next 300 frames
---------------------------------------------------------------------------
RuntimeError                              Traceback (most recent call last)
~/repos/toto/experiments/ocean/conv_run.py in <module>
     38
     39 logger.info("Training + predicting ...")
---> 40 model, outputs, pred_labels, _ = torsk.train_predict_esn(model, loader, params)
     41
     42 # weight = model.esn_cell.res_weight._values().numpy()

~/repos/toto/torsk/__init__.py in train_predict_esn(model, loader, params, outdir)
     80     init_inputs = labels[-1]
     81     outputs, out_states = model.predict(
---> 82         init_inputs, states[-1], nr_predictions=params.pred_length)
     83
     84     if outdir is not None:

~/repos/toto/torsk/models/conv_esn.py in predict(self, initial_inputs, initial_state, nr_predictions)
    147
    148         for ii in range(nr_predictions):
--> 149             state = self.esn_cell(inp, state)
    150             ext_state = torch.cat([self.ones, inp, state], dim=1)
    151             output = self.out(ext_state)

~/anaconda3/envs/py3/lib/python3.6/site-packages/torch/nn/modules/module.py in __call__(self, *input, **kwargs)
    487             result = self._slow_forward(*input, **kwargs)
    488         else:
--> 489             result = self.forward(*input, **kwargs)
    490         for hook in self._forward_hooks.values():
    491             hook_result = hook(self, input, result)

~/repos/toto/torsk/models/conv_esn.py in forward(self, inputs, state)
     78         x_inputs = []
     79         for filt in self.filters:
---> 80             conv = F.conv2d(inputs, filt)
     81             x_inputs.append(conv.reshape(-1))
     82         x_inputs = torch.cat(x_inputs, dim=0).unsqueeze(-1)

RuntimeError: $ Torch: invalid memory size -- maybe an overflow? at /opt/conda/conda-bld/pytorch_1544174967633/work/aten/src/TH/THGeneral.cpp:188

I heavily doubt that it's a memory thing. Maybe we'd have to switch to a scipy convolution... Do you have any ideas?
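One thing that might be worth ruling out before swapping libraries is a shape mismatch. F.conv2d expects a 4D input (N, C_in, H, W) and a 4D weight (C_out, C_in, kH, kW). The log above prints a tensor of size [1100, 1225], and 1225 = 35 × 35, so it is possible (an assumption, not confirmed in this thread) that the frame handed to predict is flattened rather than image-shaped. Below is a minimal, pure-Python sketch of a shape check. The helper names conv2d_output_shape and as_image_shape, the 35 × 35 frame size, and the 5 × 5 filter are all hypothetical and not part of torsk.

```python
def conv2d_output_shape(input_shape, weight_shape):
    """Return the output shape of a stride-1, unpadded 2D convolution.

    Follows the layout torch.nn.functional.conv2d expects:
    input (N, C_in, H, W) and weight (C_out, C_in, kH, kW).
    Raises ValueError if the shapes cannot be convolved.
    """
    if len(input_shape) != 4:
        raise ValueError(
            f"input must be 4D (N, C_in, H, W), got {tuple(input_shape)}")
    if len(weight_shape) != 4:
        raise ValueError(
            f"weight must be 4D (C_out, C_in, kH, kW), got {tuple(weight_shape)}")

    n, c_in, height, width = input_shape
    c_out, c_in_weight, k_height, k_width = weight_shape

    if c_in != c_in_weight:
        raise ValueError(
            f"channel mismatch: input has {c_in}, weight expects {c_in_weight}")
    if k_height > height or k_width > width:
        raise ValueError("kernel is larger than the input image")

    return (n, c_out, height - k_height + 1, width - k_width + 1)


def as_image_shape(flat_size, height, width):
    """Return the 4D shape for reshaping one flattened single-channel frame."""
    if height * width != flat_size:
        raise ValueError(
            f"cannot view {flat_size} values as a {height}x{width} image")
    return (1, 1, height, width)


if __name__ == "__main__":
    # Hypothetical numbers: a flattened 35x35 frame and one 5x5 filter.
    frame_shape = as_image_shape(1225, 35, 35)
    print(frame_shape)
    print(conv2d_output_shape(frame_shape, (1, 1, 5, 5)))
```

In forward, a check like conv2d_output_shape(tuple(inputs.shape), tuple(filt.shape)) just before the F.conv2d call would raise a readable error if the input is not 4D. If that turns out to be the cause, reshaping the frame to the shape returned by as_image_shape before the convolution may be enough, without moving to scipy.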