jabowery opened 1 year ago
Hi @jabowery,
can you try using torch.cholesky_solve() instead of torch.linalg.solve()?
So, change the code from:
W = torch.linalg.solve(self.XTy,
self.XTX + self.lambda_reg * torch.eye(
self.XTX.size(0), device=self.XTX.device))[0].t()
to
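A = self.XTX + self.lambda_reg * torch.eye(
    self.XTX.size(0), dtype=self.XTX.dtype, device=self.XTX.device)
# cholesky_solve expects the lower-triangular Cholesky factor of A, not A itself
L = torch.linalg.cholesky(A)
W = torch.cholesky_solve(self.XTy, L).t()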
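A minimal, self-contained sketch of the suggested change, assuming XTX and XTy come from a ridge-regression normal-equation system (X^T X + lambda*I) W^T = X^T Y. The random X and Y and the sizes n, d, k are hypothetical stand-ins for the real data. Two API details matter here: torch.cholesky_solve(B, L) takes the right-hand side first and the lower-triangular Cholesky factor second (computed here with torch.linalg.cholesky), and torch.linalg.solve(A, B) solves A X = B with the matrix first. The general solver is included only as a cross-check.

```python
import torch

# Hypothetical data standing in for the features/targets behind XTX and XTy.
torch.manual_seed(0)
n, d, k = 200, 10, 3
X = torch.randn(n, d, dtype=torch.float64)
Y = torch.randn(n, k, dtype=torch.float64)
lambda_reg = 1e-2

XTX = X.T @ X  # (d, d), symmetric positive semi-definite
XTy = X.T @ Y  # (d, k)

# Regularized system matrix; lambda_reg > 0 makes it symmetric positive definite.
A = XTX + lambda_reg * torch.eye(d, dtype=XTX.dtype, device=XTX.device)

# Factor A = L @ L.T (L lower-triangular), then solve using the factor.
L = torch.linalg.cholesky(A)
W_chol = torch.cholesky_solve(XTy, L).t()  # shape (k, d)

# Cross-check with the general solver: matrix first, right-hand side second.
W_lu = torch.linalg.solve(A, XTy).t()

print(torch.allclose(W_chol, W_lu))
```

Cholesky is applicable only because the regularized matrix is symmetric positive definite. For such matrices it is typically cheaper and numerically well-behaved compared with a general LU-based solve. If lambda_reg were 0 and XTX were singular, torch.linalg.cholesky would raise an error instead.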
Running example/mnist.py, I get: