OverLordGoldDragon / ssqueezepy

Synchrosqueezing, wavelet transforms, and time-frequency analysis in Python

How to use torch tensors already on GPU? #84

Closed: naba89 closed this issue 1 year ago

naba89 commented 1 year ago

Hi, thanks for your wonderful work.

Currently, only NumPy arrays are supported as input. Is there a way to use torch tensors already on the GPU?

If not, could you give me some pointers? I'd be happy to work on implementing the feature.

Best regards, Nabarun

OverLordGoldDragon commented 1 year ago

Hello. It's not supported, but the performance gains would be minimal; the compute itself takes much longer than x.cpu().numpy(). Implementing it requires making torch equivalents of the operations done on the input in every method (STFT, CQT, SSQ, etc.); as far as I'm concerned, it's an unnecessary complication.

naba89 commented 1 year ago

Thanks for the comment. I wasn't looking at it from the performance point of view of the SSQ computation itself, but rather at plugging this into existing PyTorch pipelines. In the current scenario, we need to first copy a tensor from the GPU to the host, then again to the GPU for SSQ*, back to the host, and again to the GPU for the rest of the pipeline. This adds up to significant overhead. If the functions in this module could accept on-device tensors, it would really help such use cases.
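For concreteness, a sketch of that round trip, with ssq_cwt standing in for "SSQ*" (the SSQ_GPU environment flag for enabling the torch/GPU backend is assumed from the README):

```python
import os
os.environ['SSQ_GPU'] = '1'  # assumed: enables ssqueezepy's torch/GPU backend

import torch
from ssqueezepy import ssq_cwt

x = torch.randn(2**16, device='cuda')  # produced by an upstream GPU pipeline

# round trip 1: GPU -> host, since only NumPy input is accepted
Tx, *_ = ssq_cwt(x.cpu().numpy())

# round trip 2: host -> GPU for the rest of the pipeline
# (a no-op if the output is already a GPU tensor; see the reply below)
Tx = torch.as_tensor(Tx, device='cuda')
```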

Another consideration: certain operations could be made differentiable where the underlying function is inherently differentiable, which would not be possible with torch <-> numpy conversions in between.
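To illustrate the differentiability point (plain torch behavior, nothing ssqueezepy-specific):

```python
import torch

x = torch.randn(8, requires_grad=True)
y = (x ** 2).sum()
y.backward()                    # fine: the autograd graph is intact

# a numpy round trip severs the graph (.detach() is even required,
# since .numpy() refuses tensors that require grad):
z = torch.as_tensor(x.detach().cpu().numpy())
print(z.requires_grad)          # False -> gradients cannot flow through
```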

OverLordGoldDragon commented 1 year ago

> we need to first copy a tensor from the GPU to the host, then again to the GPU for SSQ*

Yes, and it's negligible (maybe not with padtype != None, but I haven't checked the extent of the slowdown).

> back to the host, and again to the GPU for the rest of the pipeline

No; check the astensor argument of cwt, ssq_cwt, and ssq_stft. stft stays on GPU.
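A sketch of that usage, under the same SSQ_GPU assumption as above:

```python
import os
os.environ['SSQ_GPU'] = '1'  # assumed flag for the torch/GPU backend

import numpy as np
from ssqueezepy import ssq_cwt

x = np.random.randn(4096).astype('float32')

# astensor=True keeps the outputs as torch tensors on the GPU, so nothing
# needs copying back to host before downstream torch code
Tx, Wx, ssq_freqs, scales = ssq_cwt(x, astensor=True)
print(type(Tx))  # expected: torch.Tensor on CUDA
```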

> make certain operations differentiable

Indeed, but that's certainly beyond the scope of this library. If a differentiable CWT specifically is of interest, stay tuned via the README.

OverLordGoldDragon commented 1 year ago

So I just checked, and the only things stopping torch tensors as input to cwt, given padtype == None, are the argument checks and x.astype(dtype). I think a special case can reasonably be coded here, but I'd need to see some benchmarks showing this to be an actual problem.
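For instance, a hypothetical special case for the x.astype(dtype) step (sketch only; the helper name is illustrative, not the library's):

```python
import numpy as np
import torch

def _astype(x, dtype):
    """Cast `x` to `dtype` without leaving its device.

    Hypothetical helper: torch tensors take a `.to()` path so that a
    GPU-resident input never round-trips through NumPy.
    """
    if isinstance(x, torch.Tensor):
        return x.to(getattr(torch, dtype))  # e.g. 'float32' -> torch.float32
    return np.asarray(x).astype(dtype)
```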

naba89 commented 1 year ago

OK, so the padding is done entirely in NumPy. padtype == None might be workable for now; I'll check that. What kind of benchmarks did you have in mind: test cases, or performance benchmarks?

OverLordGoldDragon commented 1 year ago

Speed benchmarks, showing the overhead of x.cpu().numpy() + torch.as_tensor(x, device='cuda').
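A minimal timing sketch of that round trip (length chosen to match the 131072 mentioned below; a proper benchmark would also warm up first):

```python
import time
import torch

x = torch.randn(131072, device='cuda')

torch.cuda.synchronize()                 # don't time pending async kernels
t0 = time.perf_counter()
for _ in range(100):
    # the round trip in question: device -> host -> device
    y = torch.as_tensor(x.cpu().numpy(), device='cuda')
torch.cuda.synchronize()
print('per round trip: %.3g s' % ((time.perf_counter() - t0) / 100))
```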

naba89 commented 1 year ago

OK, let me see what I can do; I'll get back to you. Thanks for considering it, though.

OverLordGoldDragon commented 1 year ago

I confirmed the impact to be <10% even with padtype != None for len(x) == 131072. However, while investigating some code I'd copied without a close look, I found some completely redundant computation, and a surprising bottleneck in p2up that more than halves GPU speed... Definitely fixing that one, and I'll look into "already-GPU" support.

OverLordGoldDragon commented 1 year ago

False positive on p2up, but I still did some miscellaneous cleanups.