adavoudi / spdnet

Implementation of Deep SPDNet in pytorch
MIT License

RuntimeError: _th_diag not supported on CUDAType for Bool #11

Open yxbu opened 4 years ago

yxbu commented 4 years ago

I get a runtime error when I run demo.py. My PyTorch version is 1.2.0. What's the problem?

```
Epoch: 1
0it [00:00, ?it/s]Traceback (most recent call last):
  File "demo.py", line 133, in <module>
    train_loss, train_acc = train(epoch)
  File "demo.py", line 73, in train
    loss.backward()
  File "/home/yxbu/.conda/envs/torch/lib/python3.7/site-packages/torch/tensor.py", line 118, in backward
    torch.autograd.backward(self, gradient, retain_graph, create_graph)
  File "/home/yxbu/.conda/envs/torch/lib/python3.7/site-packages/torch/autograd/__init__.py", line 93, in backward
    allow_unreachable=True)  # allow_unreachable flag
  File "/home/yxbu/.conda/envs/torch/lib/python3.7/site-packages/torch/autograd/function.py", line 77, in apply
    return self._forward_cls.backward(self, *args)
  File "/home/yxbu/code/spdnet/spdnet/spd.py", line 296, in backward
    Q = max_mask.diag().float()
RuntimeError: _th_diag not supported on CUDAType for Bool
0it [00:00, ?it/s]
```

adavoudi commented 4 years ago

Hi, it may be due to changes between PyTorch versions. I have no problems with PyTorch 1.1.0.

henanjun commented 4 years ago

Is it possible to solve the problem without downgrading PyTorch?

adavoudi commented 4 years ago

Yes, but unfortunately I am very busy at the moment. If you can send a PR, it would be much appreciated.

henanjun commented 4 years ago


Thanks for your kind reply. I solved the problem by replacing `Q = max_mask.diag().float()` with `Q = max_mask.float().diag()`.
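A minimal sketch of why swapping the two calls helps. Since PyTorch 1.2, comparison operators such as `>` return `torch.bool` tensors, and the traceback shows that `diag()` on a bool tensor was not implemented for CUDA in that release. Casting to float first means `diag()` only ever sees a float tensor. The eigenvalue vector `S` and threshold `eps` below are hypothetical stand-ins; how `max_mask` is actually built in `spd.py` is assumed, not taken from the repository.

```python
import torch

# Use the GPU if present, since the original error only occurred on CUDA.
device = "cuda" if torch.cuda.is_available() else "cpu"

# Hypothetical eigenvalues and threshold (assumption, not from spd.py).
S = torch.tensor([0.5, 1e-6, 2.0, 0.0], device=device)
eps = 1e-4

# Comparison ops return a torch.bool tensor in PyTorch >= 1.2.
max_mask = S > eps

# Original order: diag() runs on the bool tensor, which raised
# "_th_diag not supported on CUDAType for Bool" on PyTorch 1.2 + CUDA.
# Q = max_mask.diag().float()

# Fixed order: cast to float32 first, then build the diagonal matrix.
Q = max_mask.float().diag()
print(Q)
```

If the surrounding tensors are float64, a variant such as `max_mask.to(S.dtype).diag()` keeps `Q` in the same dtype as `S` instead of forcing float32.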

dddc-1 commented 6 months ago

Hello, the link to download the dataset does not open. Is there a workaround?