QUVA-Lab / e2cnn

E(2)-Equivariant CNNs Library for Pytorch
https://quva-lab.github.io/e2cnn/

FieldType.transform doesn't work on GPU tensors #39

Closed ejnnr closed 3 years ago

ejnnr commented 3 years ago

When I transform a tensor that's on a GPU (using FieldType.transform or GeometricTensor.transform), I get TypeError: can't convert cuda:0 device type tensor to numpy. The issue is that in this line, it would have to be input.detach().cpu().numpy() instead of input.detach().numpy().
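For reference, a minimal snippet that would trigger the error might look like the following; the specific gspace, field type and group element are placeholders, not taken from the issue:

```python
import torch
from e2cnn import gspaces
from e2cnn import nn as enn

# Hypothetical setup: a C4-equivariant field type with a single regular field.
r2_act = gspaces.Rot2dOnR2(N=4)
feat_type = enn.FieldType(r2_act, [r2_act.regular_repr])

x = torch.randn(3, feat_type.size, 33, 33).cuda()

# Fails with "TypeError: can't convert cuda:0 device type tensor to numpy",
# because the implementation calls input.detach().numpy() on the CUDA tensor.
y = feat_type.transform(x, 1)

# The one-line fix mentioned above would be to call
#   input.detach().cpu().numpy()
# inside the library instead of input.detach().numpy().
```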

Only changing that line is probably bad because it would return a tensor on CPU when a GPU tensor is fed in. You could store the device of the tensor, move the tensor to CPU, and at the end return a new tensor on the original device. Of course, moving the tensor implicitly like that isn't ideal either, because it might cause a performance hit without the user being aware of it. Alternatively, maybe just document this and raise a custom Exception ("tensors need to be moved to cpu before transforming")? Not sure which solution is better.
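Something like this for the first option (the wrapper name is just illustrative, and transform stands for the existing numpy-based implementation):

```python
import torch

def transform_on_original_device(field_type, input: torch.Tensor, element):
    # Remember where the input lives, run the existing numpy-based
    # transform on a CPU copy, then move the result back.
    device = input.device
    output = field_type.transform(input.cpu(), element)
    return output.to(device)
```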

kristian-georgiev commented 3 years ago

Not the author, but it seems like it might be easier/more efficient if np.einsum here and np.rot90 in the different ._basespace_actions are replaced with their torch equivalents. This way, there should be no need to move the tensor between the GPU and CPU(?)
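Just to illustrate the idea (not the library's actual code), both operations have direct torch counterparts that run on CUDA tensors:

```python
import torch

x = torch.randn(1, 8, 32, 32, device="cuda")   # (batch, channels, H, W)
rep = torch.randn(8, 8, device="cuda")          # matrix acting on the channel dimension

# np.rot90(x, k=1, axes=(2, 3))  ->  torch.rot90(x, k=1, dims=(2, 3))
x_rot = torch.rot90(x, k=1, dims=(2, 3))

# np.einsum('oi,bihw->bohw', rep, x)  ->  torch.einsum with the same subscripts
y = torch.einsum('oi,bihw->bohw', rep, x_rot)   # stays on the GPU throughout
```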

Gabri95 commented 3 years ago

hey there! Sorry for the late reply.

I agree that both .transform() methods should raise an exception in that case.

Currently, the backend of the library dealing with groups, representations and the generation of the kernels relies only on numpy. The e2cnn.nn subpackage then tries to provide an interface similar to PyTorch and wraps the np.ndarray objects produced by the other subpackages inside torch.Tensor. The rationale was that I wanted the backend of the library to be independent of PyTorch (so that one could also extract the filter bases for other applications). I am now a bit more convinced that the backend should also support PyTorch, or at least the e2cnn.gspace should.

For the moment, I'd still recommend not using tensor.transform(el) directly when this implies some interpolation of the input. This is because one probably wants the freedom to customise this operation further (e.g. change the interpolation order).
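If you need that kind of control today, one option is to apply the transformation manually on the numpy side. A minimal sketch, assuming rep is the (C, C) matrix of the field type's representation evaluated at the chosen rotation:

```python
import numpy as np
from scipy import ndimage

def rotate_field(x: np.ndarray, rep: np.ndarray, angle_deg: float, order: int = 3) -> np.ndarray:
    # x has shape (batch, C, H, W): first mix the channels with the representation matrix...
    x = np.einsum('oi,bihw->bohw', rep, x)
    # ...then rotate the spatial plane, choosing the interpolation order explicitly.
    return ndimage.rotate(x, angle_deg, axes=(-2, -1), reshape=False, order=order)
```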

Anyways, in a future release, I will try to fix the gspace as @kristian-georgiev suggested, using torch operations.

Thanks! Gabriele