nschaetti / EchoTorch

A Python toolkit for Reservoir Computing and Echo State Network experimentation based on PyTorch. EchoTorch is the only Python module available to easily create Deep Reservoir Computing models.
GNU General Public License v3.0

Algorithm breaks when using cuda #13

Closed: papalotis closed this issue 3 years ago

papalotis commented 5 years ago

The ESN algorithm breaks (throws an exception) when it is run on a GPU. More specifically, the exception is raised in the function finalize in /echotorch/nn/RRCell.py.

RuntimeError                              Traceback (most recent call last)
~/bachelor/test.py in <module>
     69 # Finalize training
---> 70 esn.finalize()
     72 # Train MSE

~/libs/EchoTorch/echotorch/nn/ESN.py in finalize(self)
    187         """
    188         # Finalize output training
--> 189         self.output.finalize()
    191         # Not in training mode anymore

~/libs/EchoTorch/echotorch/nn/RRCell.py in finalize(self, train)
    179                 # Algo
--> 180                 ridge_xTx = self.xTx + self.ridge_param * torch.eye(self.input_dim + self.with_bias, dtype=self.dtype)
    181                 inv_xTx = ridge_xTx.inverse()
    182                 self.w_out.data = torch.mm(inv_xTx, self.xTy).data

RuntimeError: expected type torch.cuda.FloatTensor but got torch.FloatTensor

This error was produced by running the Mackey Glass example. The workaround that I have for now is:

# Algo
algo_eye = torch.eye(self.input_dim + self.with_bias, dtype=self.dtype)

# Move the identity matrix to the GPU if xTx already lives there
if self.xTx.is_cuda:
    algo_eye = algo_eye.cuda()

ridge_xTx = self.xTx + self.ridge_param * algo_eye

Definitely not the best solution, but it works for now.
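A device-agnostic variant of this workaround could be sketched as follows. It assumes that creating the identity matrix with the same dtype and device as xTx is enough to avoid the type mismatch. The function ridge_solve and the random test data are hypothetical and only mirror the three lines shown in the traceback; they are not part of EchoTorch, and this is not necessarily how the library fixed the issue.

```python
import torch


def ridge_solve(xTx, xTy, ridge_param):
    """Hypothetical helper mirroring the ridge step shown in the traceback."""
    # Create the identity directly with the dtype and device of xTx,
    # so no explicit .cuda() call or is_cuda check is needed.
    eye = torch.eye(xTx.shape[0], dtype=xTx.dtype, device=xTx.device)
    ridge_xTx = xTx + ridge_param * eye
    # Same computation as in the traceback: inverse, then matrix product.
    return torch.mm(ridge_xTx.inverse(), xTy)


# Runs on the GPU when one is available, otherwise on the CPU.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
x = torch.randn(50, 4, device=device)  # stand-in for reservoir states
y = torch.randn(50, 1, device=device)  # stand-in for target outputs
xTx = x.t() @ x
xTy = x.t() @ y
w_out = ridge_solve(xTx, xTy, ridge_param=0.1)
```

Because the identity inherits its device from xTx, the same code path works whether or not the model has been moved to the GPU.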

MrMonotreme commented 4 years ago

Yes, I had a similar issue. Could a commit be scheduled to remedy this problem? I tried changing the GPU device globally, but could not change the device of a tensor that lives inside the EchoTorch library. I believe @papalotis's solution will work best for my purposes.

nschaetti commented 3 years ago

Solved in 6159b46ed280df738b10e0aa3d290dc4314b883f (22/01/2021), dev branch. No tests with NARMA-10 with CUDA.