wavefrontshaping / complexPyTorch

A high-level toolbox for using complex-valued neural networks in PyTorch
MIT License

ComplexDropout2d Device Error #30

Open lucacoma opened 1 year ago

lucacoma commented 1 year ago

Hi, thank you for the nice library.

There seems to be a small mistake in the complexPyTorch.complexLayers.ComplexDropout2d layer, which raises a device-mismatch error (torch version 2.0.1+cu118):

```
....
line 106, in complex_dropout
    return mask*input
RuntimeError: Expected all tensors to be on the same device, but found at least two devices, cuda:0 and cpu!
```

I managed to solve it by simply moving the mask to the right device in complexPyTorch.complexFunctions.complex_dropout2d, as follows:

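```python
import torch
import torch.nn.functional as F


def complex_dropout2d(input, p=0.5, training=True):
    # need to have the same dropout mask for real and imaginary part,
    # this is not a clean solution!
    device = input.device
    mask = torch.ones(*input.shape, dtype=torch.float32, device=device)
    # F.dropout2d already rescales the kept channels by 1/(1-p) in training
    # mode, so no extra 1/(1-p) factor is applied here
    mask = F.dropout2d(mask, p, training)
    mask = mask.type(input.dtype)  # the cast must be reassigned to take effect
    mask = mask.to(device)  # Line added: keep the mask on the input's device
    return mask * input
```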
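A device-agnostic variant can also be sketched by deriving the mask from the input itself. This is a minimal sketch, assuming a complex input of shape (N, C, H, W); the name `shared_mask_dropout2d` is hypothetical and not part of complexPyTorch. `torch.ones_like(input.real)` creates the mask on the same device as the input, with the matching real dtype (e.g. float32 for complex64), so no explicit `.to(device)` call is needed. `torch.nn.functional.dropout2d` already rescales the kept channels by 1/(1-p) during training, so no further scaling is applied.

```python
import torch
import torch.nn.functional as F


def shared_mask_dropout2d(input: torch.Tensor, p: float = 0.5, training: bool = True) -> torch.Tensor:
    """Channel-wise dropout for a complex (N, C, H, W) tensor (hypothetical helper).

    The same real-valued mask multiplies both the real and imaginary parts,
    so a dropped channel is zeroed in both.
    """
    # ones_like(input.real) inherits the input's device and its real dtype,
    # so the mask cannot end up on the CPU while the input is on the GPU.
    mask = torch.ones_like(input.real)
    # dropout2d zeroes whole channels with probability p and scales the
    # survivors by 1/(1-p) in training mode; in eval mode it returns the ones.
    mask = F.dropout2d(mask, p, training)
    # real mask * complex input -> complex output on the input's device
    return mask * input


if __name__ == "__main__":
    device = "cuda" if torch.cuda.is_available() else "cpu"
    x = torch.randn(2, 3, 4, 4, dtype=torch.cfloat, device=device)
    y = shared_mask_dropout2d(x, p=0.5, training=True)
    print(y.device, y.dtype)
```

The same call runs unchanged on CPU and GPU because nothing in the function names a device explicitly.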

Best!

nctamer commented 2 weeks ago

The same happens for all the dropout layers. Any updates on an official fix?
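The same device-safe pattern can be sketched for element-wise complex dropout (the `complex_dropout` function that appears in the traceback above). This is a minimal sketch, not the library's code; `shared_mask_dropout` is a hypothetical name. Each element is either dropped in both its real and imaginary parts or kept and rescaled by 1/(1-p), which `torch.nn.functional.dropout` does on its own in training mode.

```python
import torch
import torch.nn.functional as F


def shared_mask_dropout(input: torch.Tensor, p: float = 0.5, training: bool = True) -> torch.Tensor:
    """Element-wise dropout for a complex tensor of any shape (hypothetical helper)."""
    # Real-valued mask built on the input's device with the matching real dtype;
    # F.dropout zeroes entries with probability p and scales the rest by 1/(1-p).
    mask = F.dropout(torch.ones_like(input.real), p, training)
    # Multiplying by one real mask keeps real and imaginary parts in sync.
    return mask * input


if __name__ == "__main__":
    x = torch.randn(5, 7, dtype=torch.cfloat)
    print(shared_mask_dropout(x, p=0.25))
```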

wavefrontshaping commented 4 days ago

Hi, I can look at it when I have some time, but I am no longer working on this code, which I consider obsolete and no longer need now that complex tensors are implemented natively in current versions of PyTorch. Do not hesitate to fork it, and why not make a pull request if you need such changes; I would look at it. Best,
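For reference, a minimal sketch of what native complex support in recent PyTorch versions looks like without complexPyTorch: parameters are ordinary tensors with `dtype=torch.cfloat`, and autograd returns complex gradients for a real-valued loss, so a plain gradient-descent step can be applied to them. The shapes, the bias-free linear map, and the squared-magnitude loss are illustrative assumptions.

```python
import torch

torch.manual_seed(0)

# Complex input batch and a complex weight matrix (a bias-free linear map).
x = torch.randn(8, 4, dtype=torch.cfloat)
w = torch.randn(4, 2, dtype=torch.cfloat, requires_grad=True)

# Complex matrix product, then a real-valued loss on the output magnitudes.
out = x @ w
loss = out.abs().pow(2).mean()
loss.backward()

# The gradient has the same complex dtype and shape as the parameter.
print(w.grad.dtype, w.grad.shape)

# A manual gradient-descent step applied directly to the complex parameter.
with torch.no_grad():
    w -= 0.1 * w.grad
```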