QUVA-Lab / e2cnn

E(2)-Equivariant CNNs Library for Pytorch
https://quva-lab.github.io/e2cnn/

RuntimeError: a view of a leaf Variable that requires grad is being used in an in-place operation. #34

Closed: SunHaozhe closed this issue 3 years ago

SunHaozhe commented 3 years ago

I believe the package needs to be updated for newer versions of PyTorch. The code works with PyTorch 1.1, but not with PyTorch 1.8.

This issue can be reproduced using another package:

  1. git clone https://github.com/AllanYangZhou/metalearning-symmetries.git
  2. python generate_synthetic_data.py --problem 2d_rot8

Another issue is support for machines without CUDA. To reproduce it, run the following two steps on a CPU-only machine:

  1. git clone https://github.com/AllanYangZhou/metalearning-symmetries.git
  2. python generate_synthetic_data.py --problem 2d_rot8_flip

Gabri95 commented 3 years ago

Hi @SunHaozhe

Unfortunately, I cannot test the library with PyTorch 1.8 on CUDA at the moment.

However, looking at your code, I think you can fix the runtime error in the title by calling, for example, gnn.init.generalized_he_init(conv.weights.data, conv.basisexpansion) instead of gnn.init.generalized_he_init(conv.weights, conv.basisexpansion).

The problem is that the weight initialization methods operate in-place (like PyTorch's own init functions). I see that PyTorch wraps its init methods in a with torch.no_grad() block, so maybe I should do the same.
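To illustrate the mechanism described above, here is a minimal pure-PyTorch sketch. The function hypothetical_inplace_init is a made-up stand-in, not e2cnn's actual code; it assumes the init routine writes into the tensor through a slice (a view), which is one way to trigger this exact error on recent PyTorch versions. It shows the failure when a Parameter is passed directly, the .data workaround, and the torch.no_grad() alternative.

```python
import torch


def hypothetical_inplace_init(tensor: torch.Tensor) -> None:
    # Hypothetical stand-in for an in-place init routine (NOT e2cnn's code):
    # it overwrites the tensor through a full slice, i.e. through a view.
    tensor[:] = torch.randn_like(tensor)


# A Parameter is a leaf tensor with requires_grad=True, like a conv's weights.
w = torch.nn.Parameter(torch.zeros(3, 4))

# 1) Passing the Parameter itself: autograd rejects the in-place write.
try:
    hypothetical_inplace_init(w)
    raised = False
except RuntimeError as err:
    raised = True
    print("RuntimeError:", err)

# 2) The .data workaround: it shares storage with w but is not tracked
#    by autograd, so the write goes through.
hypothetical_inplace_init(w.data)
after_data = w.detach().clone()

# 3) Alternative: disable graph recording, as torch.nn.init does internally.
with torch.no_grad():
    hypothetical_inplace_init(w)

# w is still a trainable leaf after both workarounds.
print(w.is_leaf, w.requires_grad)
```

Either workaround leaves w as a leaf with requires_grad=True, so it still trains normally afterwards; the torch.no_grad() variant has the advantage of living inside the init function, so callers would not need to remember .data.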

I need to think a bit more about this to make sure it wouldn't cause other problems, but in the meantime you can use the trick above to solve your issue.

Regarding the second issue, I ran the script you mentioned on a machine without a GPU (using my own conda environment with PyTorch 1.8) and didn't encounter any problems. Note, however, that generate_synthetic_data assumes a GPU is available, since it explicitly calls .to('cuda'). You need to remove that call from the code if you want to run it on a CPU.
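The script itself is not reproduced here. As a generic sketch (the model and input below are hypothetical placeholders, not taken from generate_synthetic_data), a hard-coded .to('cuda') can be replaced with a device chosen at runtime, so the same code runs on both GPU and CPU-only machines:

```python
import torch

# Use CUDA when it is available, otherwise fall back to the CPU,
# instead of hard-coding .to('cuda').
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# Hypothetical placeholder model and batch, moved/created on that device.
model = torch.nn.Linear(4, 2).to(device)
x = torch.randn(8, 4, device=device)

with torch.no_grad():
    y = model(x)

print(device, tuple(y.shape))
```

Creating tensors directly with device=... avoids an extra host-to-device copy compared with building them on the CPU and calling .to(device) afterwards.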

Let me know if this solves your issues.

Best, Gabriele

SunHaozhe commented 3 years ago

Thanks a lot for your reply! I solved the issue by switching to a machine with CUDA and the correct PyTorch version. I hope this issue helps anyone who encounters the same problem :)