stefanonardo / pytorch-esn

An Echo State Network module for PyTorch.
MIT License

mnist: RuntimeError: linalg.solve: A must be batches of square matrices, but they are 501 by 10 matrices #17

Open: jabowery opened this issue 1 year ago

jabowery commented 1 year ago

Running example/mnist.py, I get:

/torchesn/nn/echo_state_network.py", line 237, in fit
    W = torch.linalg.solve(self.XTy,
RuntimeError: linalg.solve: A must be batches of square matrices, but they are 501 by 10 matrices
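The error appears to come from argument order rather than from the matrices themselves. `torch.linalg.solve(A, B)` solves `A X = B` and requires `A` to be square, so passing `self.XTy` (501 × 10) as the first argument fails. The failing call reads like the older `torch.solve(B, A)`, which took the right-hand side first and returned a `(solution, LU)` tuple; that would also explain the trailing `[0]`, which is unnecessary with `torch.linalg.solve` because it returns a single tensor. A minimal reproduction, assuming random data with the shapes from the error message and a hypothetical `lambda_reg`:

```python
import torch

torch.manual_seed(0)
n_features, n_outputs = 501, 10  # shapes taken from the error message
lambda_reg = 1e-3                # hypothetical regularization strength

# Random stand-ins for the reservoir states and targets (not the real MNIST data).
X = torch.randn(2000, n_features, dtype=torch.float64)
y = torch.randn(2000, n_outputs, dtype=torch.float64)
XTX = X.t() @ X                  # (501, 501)
XTy = X.t() @ y                  # (501, 10)
A = XTX + lambda_reg * torch.eye(n_features, dtype=torch.float64)

# Old argument order (right-hand side first): linalg.solve treats XTy as the
# coefficient matrix, which is not square, so it raises.
try:
    torch.linalg.solve(XTy, A)
    failed = False
except RuntimeError as err:
    failed = True
    print(err)

# torch.linalg.solve order: square matrix first, right-hand side second.
# It returns one tensor, so no [0] is needed before transposing.
W = torch.linalg.solve(A, XTy).t()  # (10, 501)
print(W.shape)
```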

stefanonardo commented 1 year ago

Hi @jabowery, can you try using torch.cholesky_solve() instead of torch.linalg.solve()?

So, change the code from:

W = torch.linalg.solve(self.XTy,
                       self.XTX + self.lambda_reg * torch.eye(
                           self.XTX.size(0), device=self.XTX.device))[0].t()

to

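# cholesky_solve expects the (lower-triangular) Cholesky factor, not the matrix itself
L = torch.linalg.cholesky(self.XTX + self.lambda_reg * torch.eye(
    self.XTX.size(0), device=self.XTX.device))
W = torch.cholesky_solve(self.XTy, L).t()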
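Note that `torch.cholesky_solve(B, L)` expects `L` to be the Cholesky factor of the system matrix (lower-triangular by default), not the matrix itself. Given the unfactored matrix, it runs without error but reads only its lower triangle, so the resulting `W` would be silently wrong. Since `XTX + lambda_reg * I` is symmetric positive definite for `lambda_reg > 0`, it can be factored with `torch.linalg.cholesky` first. The sketch below compares both routes with `torch.linalg.solve` called in `(A, B)` order, assuming random data and a hypothetical `lambda_reg`:

```python
import torch

torch.manual_seed(0)
n_features, n_outputs = 501, 10  # shapes taken from the error message
lambda_reg = 1e-3                # hypothetical regularization strength

X = torch.randn(2000, n_features, dtype=torch.float64)
y = torch.randn(2000, n_outputs, dtype=torch.float64)
XTX = X.t() @ X
XTy = X.t() @ y
A = XTX + lambda_reg * torch.eye(n_features, dtype=torch.float64)

# Route 1: general solver, coefficient matrix first.
W_solve = torch.linalg.solve(A, XTy).t()

# Route 2: factor A = L L^T once, then pass the factor L to cholesky_solve.
L = torch.linalg.cholesky(A)
W_chol = torch.cholesky_solve(XTy, L).t()

# Passing A itself is accepted, but A's lower triangle is used as if it were L,
# so this solves a different system.
W_wrong = torch.cholesky_solve(XTy, A).t()

print(torch.allclose(W_solve, W_chol))   # expected True
print(torch.allclose(W_solve, W_wrong))  # expected False
```

Both correct routes give the same readout weights; the Cholesky route exploits the symmetry of the ridge-regression normal equations, while `torch.linalg.solve` makes no such assumption.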