fbcotter / pytorch_wavelets

PyTorch implementation of the 2D Discrete Wavelet Transform (DWT) and Dual-Tree Complex Wavelet Transform (DTCWT), and a DTCWT-based ScatterNet

A problem with multi-GPU training #27

Closed: freepoet closed this issue 3 years ago

freepoet commented 3 years ago

Hi, fbcotter. I'm very interested in your dissertation and this GitHub repo.
I ran into a problem when I called DTCWTForward().
Can DTCWTForward() support training on multiple GPUs? I want to import pytorch_wavelets and then train a CNN on 8 GPUs (roughly the setup sketched below).
Thx!!!
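
For reference, a minimal sketch of that kind of setup, assuming the pytorch_wavelets API as documented in this repo; the WaveCNN class, layer sizes, and J=2 are illustrative, not from this thread. The transform is registered as a submodule and the whole model is wrapped in torch.nn.DataParallel so it runs on all visible GPUs (e.g. 8):

```python
import torch
import torch.nn as nn
from pytorch_wavelets import DTCWTForward

class WaveCNN(nn.Module):
    """Toy classifier that uses the DTCWT lowpass as its first feature map (illustrative)."""
    def __init__(self, in_channels=3, num_classes=10):
        super().__init__()
        # Assigning the transform as an attribute registers it as a submodule,
        # so .cuda() / DataParallel move and replicate its filter buffers as well.
        self.xfm = DTCWTForward(J=2, biort='near_sym_b', qshift='qshift_b')
        self.conv = nn.Conv2d(in_channels, 32, kernel_size=3, padding=1)
        self.head = nn.Linear(32, num_classes)

    def forward(self, x):
        yl, yh = self.xfm(x)          # lowpass tensor + list of complex highpass bands
        f = torch.relu(self.conv(yl))
        f = f.mean(dim=(2, 3))        # global average pool
        return self.head(f)

# DataParallel replicates the model (wavelet filters included) on every visible GPU.
model = nn.DataParallel(WaveCNN()).cuda()
out = model(torch.randn(16, 3, 64, 64).cuda())
print(out.shape)                      # torch.Size([16, 10])
```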

fbcotter commented 3 years ago

Hi @freepoet, sorry for not getting back to you! Yes, it should support this naturally, since the classes all inherit from torch.nn.Module. Did you manage to get it working?
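
For reference, a quick check along those lines, using the constructor arguments from the repo README (the tensor sizes are just an example): because the transform is an nn.Module, calling .cuda() moves its internal filter buffers too, so the coefficients come back on the GPU.

```python
import torch
from pytorch_wavelets import DTCWTForward

xfm = DTCWTForward(J=3, biort='near_sym_b', qshift='qshift_b').cuda()
x = torch.randn(10, 5, 64, 64).cuda()
yl, yh = xfm(x)                  # lowpass and list of highpass coefficients
print(yl.device, yh[0].device)   # both expected to be on the GPU
```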

Mojzaar commented 2 years ago

I have the same problem. When I use DWTForward as a down-sampling layer inside my model, the resulting output stays on the CPU, even though I have used PyTorch's DataParallel to move the entire model onto multiple GPUs. Any suggestions?
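
Not a confirmed diagnosis, but one thing worth checking: nn.DataParallel only replicates modules that are registered on the wrapped model, so a DWTForward kept in a plain Python list, held as a global, or created inside forward() will not have its filter buffers moved to the GPUs. A minimal sketch of the registered-submodule pattern (the Down block, channel counts, and wavelet choice are illustrative assumptions):

```python
import torch
import torch.nn as nn
from pytorch_wavelets import DWTForward

class Down(nn.Module):
    """Wavelet down-sampling block (illustrative)."""
    def __init__(self, channels=16):
        super().__init__()
        # Direct attribute assignment registers DWTForward as a submodule, so
        # DataParallel copies its filter buffers to each GPU replica. Keeping it
        # in a plain list, as a global, or building it inside forward() would not.
        self.dwt = DWTForward(J=1, wave='db1', mode='zero')
        self.conv = nn.Conv2d(channels, channels, kernel_size=3, padding=1)

    def forward(self, x):
        yl, yh = self.dwt(x)          # yl: lowpass (half resolution), yh: highpass list
        return self.conv(yl)

model = nn.DataParallel(Down()).cuda()
out = model(torch.randn(8, 16, 64, 64).cuda())
print(out.device)                     # gathered on the default output device, e.g. cuda:0
```

With this pattern the gathered output should land on the default output device (normally cuda:0) rather than the CPU.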