QUVA-Lab / e2cnn

E(2)-Equivariant CNNs Library for Pytorch
https://quva-lab.github.io/e2cnn/
Other

Fix for torch >= 1.9 #44

Closed: marcelroed closed this pull request 3 years ago

marcelroed commented 3 years ago

In its current form, `from e2cnn import nn` crashes with torch >= 1.8. Here's a fix.

Source for this fix: this issue.
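For illustration, here is a minimal sketch of the kind of version-gated import such a fix typically uses. It assumes, without confirmation from this thread, that the crash comes from `torch._six.container_abcs`, an import other projects reported as missing in PyTorch 1.9. `parse_major_minor` and `load_container_abcs` are hypothetical helpers, not e2cnn code.

```python
import collections.abc
import re
from typing import Tuple


def parse_major_minor(version: str) -> Tuple[int, int]:
    """Extract (major, minor) from strings like '1.8.1', '1.9.0+cu111', '1.10.0.dev20210601'."""
    match = re.match(r"(\d+)\.(\d+)", version)
    if match is None:
        raise ValueError(f"unrecognized version string: {version!r}")
    return int(match.group(1)), int(match.group(2))


def load_container_abcs(torch_version: str):
    """Return the container ABC module to use for a given torch version (hypothetical helper)."""
    if parse_major_minor(torch_version) <= (1, 8):
        # Assumed legacy location; only imported on old torch releases, so torch must be installed.
        from torch._six import container_abcs
        return container_abcs
    # Newer torch: use the standard-library module directly.
    return collections.abc
```

Inside a library this would be called as `load_container_abcs(torch.__version__)`. Comparing `(major, minor)` tuples rather than the minor number alone keeps the check correct if the major version ever changes.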

Gabri95 commented 3 years ago

Hey there!

Thanks for spotting this error!

I will merge this pull request in a moment.

Best, Gabriele

Gabri95 commented 3 years ago

Hi @marcelroed

I tried `from e2cnn import nn` in a conda environment with PyTorch 1.8.1 and e2cnn 1.9, but I did not get any import error. From the link you posted, it seems this problem arises only with PyTorch >= 1.9. Indeed, I can reproduce the error with PyTorch 1.9.

I would then update your fix to use the condition `TORCH_MINOR <= 8` rather than `TORCH_MINOR < 8`.
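The two conditions differ only at minor version 8, which is exactly the PyTorch 1.8.1 case tested above. A small sketch (the helper names are hypothetical, and `torch_minor` assumes a plain `major.minor.patch` string) makes the boundary explicit:

```python
def torch_minor(version: str) -> int:
    """Minor component of a version string such as '1.8.1' (hypothetical helper)."""
    return int(version.split(".")[1])


def legacy_if_lt_8(version: str) -> bool:
    # Original condition: legacy path only for minor < 8.
    return torch_minor(version) < 8


def legacy_if_le_8(version: str) -> bool:
    # Proposed condition: legacy path for minor <= 8.
    return torch_minor(version) <= 8


for v in ["1.7.1", "1.8.1", "1.9.0"]:
    print(v, legacy_if_lt_8(v), legacy_if_le_8(v))
# 1.7.1 True True
# 1.8.1 False True
# 1.9.0 False False
```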

Is that ok?

marcelroed commented 3 years ago

Hey @Gabri95, I changed the version check to `<= 1.8`.

Thanks for testing this with 1.8 and 1.9!

Gabri95 commented 3 years ago

Perfect, thanks again for your contribution!

Best, Gabriele