jonkhler / s2cnn

Spherical CNNs
MIT License
939 stars 176 forks

Running MNIST Example Problems #52

Open kkokilep opened 3 years ago

kkokilep commented 3 years ago

Hello,

Thanks for the great work! I am having trouble getting this code to run, starting with the example given in the repository. When I run python run.py in the mnist folder, I get several errors that I would not expect to have to fix myself, since this is the basic example. The first error is this:

Traceback (most recent call last):
  File "run.py", line 257, in <module>
    main(args.network)
  File "run.py", line 221, in main
    outputs = classifier(images)
  File "/home/kiran/anaconda3/envs/OCT_latent_space/lib/python3.8/site-packages/torch/nn/modules/module.py", line 1051, in _call_impl
    return forward_call(*input, **kwargs)
  File "run.py", line 89, in forward
    x = self.conv1(x)
  File "/home/kiran/anaconda3/envs/OCT_latent_space/lib/python3.8/site-packages/torch/nn/modules/module.py", line 1051, in _call_impl
    return forward_call(*input, **kwargs)
  File "/home/kiran/Desktop/Dev/s2cnn/mnist/s2cnn/soft/s2_conv.py", line 40, in forward
    x = S2_fft_real.apply(x, self.b_out)  # [l * m, batch, feature_in, complex]
  File "/home/kiran/Desktop/Dev/s2cnn/mnist/s2cnn/soft/s2_fft.py", line 233, in forward
    return s2_fft(as_complex(x), b_out=ctx.b_out)
  File "/home/kiran/Desktop/Dev/s2cnn/mnist/s2cnn/soft/s2_fft.py", line 56, in s2_fft
    output[s] = torch.einsum("bm,zbmc->mzc", (wigner[:, s], xx))
  File "/home/kiran/anaconda3/envs/OCT_latent_space/lib/python3.8/site-packages/torch/functional.py", line 297, in einsum
    return einsum(equation, *_operands)
  File "/home/kiran/anaconda3/envs/OCT_latent_space/lib/python3.8/site-packages/torch/functional.py", line 299, in einsum
    return _VF.einsum(equation, operands)  # type: ignore[attr-defined]
RuntimeError: expected scalar type Float but found ComplexFloat

If I try to fix this by making xx a real tensor, I get further errors down the line. Does anyone have advice on this?

mariogeiger commented 3 years ago

I guess it's because the PyTorch FFT API changed. Someone recently made a PR to call the new PyTorch API functions, but maybe it was not sufficient.

I don't maintain this code anymore, but I maintain this one, and it might contain the functionality that you need.
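
For context on the API change mentioned above: older PyTorch code stored complex values as real tensors with a trailing dimension of size 2, while the torch.fft module works on complex dtypes. The sketch below is only an illustration of converting between the two layouts and of aligning dtypes before an einsum. It is not the s2cnn fix, and the tensor names and shapes (x_pairs, weights) are hypothetical stand-ins rather than anything taken from s2_fft.py.

```python
import torch

# Old layout: a real tensor whose last dimension (size 2) holds (real, imag).
x_pairs = torch.randn(4, 3, 2)

# torch.view_as_complex reinterprets that layout as a complex tensor.
x_complex = torch.view_as_complex(x_pairs)  # shape [4, 3], complex64

# The torch.fft module operates on complex tensors directly.
spectrum = torch.fft.fft(x_complex, dim=0)  # shape [4, 3], complex64

# Hypothetical real-valued weights (a stand-in, not the s2cnn Wigner matrix).
# Casting them to the complex dtype gives both einsum operands the same dtype.
weights = torch.randn(4, 5)
mixed = torch.einsum("bm,bc->mc", weights.to(spectrum.dtype), spectrum)

# torch.view_as_real converts back to the trailing (real, imag) layout.
mixed_pairs = torch.view_as_real(mixed)  # shape [5, 3, 2], float32
```

Whether a cast like this is appropriate inside s2_fft depends on how the rest of the routine treats the complex dimension, so treat it as a starting point for debugging rather than a patch.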

EricPengShuai commented 2 years ago

@kkokilep @mariogeiger When I run the S2CNN MNIST example, I hit the following problem with the lie_learn module.

Traceback (most recent call last):
  File "<frozen importlib._bootstrap>", line 983, in _find_and_load
  File "<frozen importlib._bootstrap>", line 967, in _find_and_load_unlocked
  File "<frozen importlib._bootstrap>", line 677, in _load_unlocked
  File "<frozen importlib._bootstrap_external>", line 728, in exec_module
  File "<frozen importlib._bootstrap>", line 219, in _call_with_frames_removed
  File "C:\Users\peng\anaconda3\envs\pytorch_1.8\lib\site-packages\lie_learn\representations\SO3\wigner_d.py", line 5, in <module>
    from lie_learn.representations.SO3.irrep_bases import change_of_basis_matrix
ModuleNotFoundError: No module named 'lie_learn.representations.SO3.irrep_bases'

mxiao18 commented 2 years ago

Same issue. I noticed that the so3_fft routine contains a conditional that tests whether the input is a CUDA device tensor. You can run into this issue if the container was not launched correctly (e.g. without a GPU) or if CUDA is not set up correctly. I finally fixed it by running the container with the GPU enabled.
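
A quick way to check the two environment problems raised in this thread (a missing lie_learn submodule and a container without a working GPU) is a small diagnostic script. This is a minimal sketch: the helper names module_available and environment_report are made up for illustration, and it only reports what is visible to the current Python interpreter.

```python
import importlib.util


def module_available(name):
    """Return True if `name` can be located in the current environment.

    find_spec imports the parent packages of a dotted name, but not the
    target module itself.
    """
    try:
        return importlib.util.find_spec(name) is not None
    except (ImportError, ValueError):
        # Raised when a parent package is missing or cannot be imported.
        return False


def environment_report():
    """Collect the facts relevant to the errors discussed in this thread."""
    report = {
        "lie_learn": module_available("lie_learn"),
        "lie_learn.representations.SO3.irrep_bases": module_available(
            "lie_learn.representations.SO3.irrep_bases"
        ),
    }
    try:
        import torch
    except ImportError:
        report["torch"] = None
        report["cuda_available"] = False
    else:
        report["torch"] = torch.__version__
        report["cuda_available"] = torch.cuda.is_available()
    return report


if __name__ == "__main__":
    for key, value in environment_report().items():
        print(f"{key}: {value}")
```

If cuda_available is False inside a container, the container was probably started without GPU access. If the irrep_bases entry is False while lie_learn is True, the installed lie_learn package is likely incomplete.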