ejmahler / RustFFT

RustFFT is a high-performance FFT library written in pure Rust.
Apache License 2.0

Short-Time Fourier Transform support? #140

Open matbee-eth opened 4 months ago

matbee-eth commented 4 months ago

Would it be possible to get STFT / iSTFT support?

HEnquist commented 4 months ago

To me that feels more like something that belongs in a separate library. How about https://crates.io/crates/ruststft ? It uses RustFFT for the FFT work.

phudtran commented 1 month ago

> Would it be possible to get STFT / iSTFT support?

I made this to match PyTorch's output. It's still WIP and very slow compared to PyTorch, but it is pretty close in accuracy.

https://github.com/phudtran/rustft
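For readers comparing the runs below, here is a small sketch of how `signal length`, `n_fft`, and `hop_length` translate into STFT output shape. It assumes PyTorch-style conventions: with centering, the signal is padded by `n_fft / 2` on each side and the frame count is `1 + len / hop`; without centering it is `1 + (len - n_fft) / hop`. Whether the benchmark used centering is an assumption, and the function names are hypothetical.

```rust
/// Frame count when the signal is padded by n_fft / 2 on both sides
/// (the PyTorch `center=True` convention; assumed here, not confirmed by the thread).
fn frames_centered(len: usize, hop: usize) -> usize {
    1 + len / hop
}

/// Frame count with no padding: only frames that fit entirely inside the signal.
fn frames_uncentered(len: usize, n_fft: usize, hop: usize) -> usize {
    if len < n_fft {
        0
    } else {
        1 + (len - n_fft) / hop
    }
}

/// Number of frequency bins kept for a real input (one-sided spectrum).
fn onesided_bins(n_fft: usize) -> usize {
    n_fft / 2 + 1
}

fn main() {
    // (signal length, n_fft, hop_length) for the three configurations in the benchmark log.
    let configs = [(16384, 1024, 512), (32768, 2048, 1024), (65536, 4096, 2048)];
    for (len, n_fft, hop) in configs {
        println!(
            "len={} n_fft={} hop={}: {} centered frames, {} uncentered frames, {} bins",
            len,
            n_fft,
            hop,
            frames_centered(len, hop),
            frames_uncentered(len, n_fft, hop),
            onesided_bins(n_fft)
        );
    }
}
```

Under these assumptions all three configurations produce the same number of frames per channel, so the growth in run time comes from the larger FFT size and the higher channel count rather than from more frames.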

Testing with: 2 channels, signal length 16384, n_fft 1024, hop_length 512
Average Rust roundtrip error: 1.2937405892934302e-16
Average PyTorch roundtrip error: 9.561864467600023e-09
Average roundtrip error (Rust STFT -> PyTorch ISTFT): 1.771066294604266e-08
Average roundtrip error (PyTorch STFT -> Rust ISTFT): 1.5421215870497426e-08

Average run times:
Rust STFT + ISTFT: 0.058335 seconds
PyTorch STFT + ISTFT: 0.005722 seconds
Rust STFT: 0.029399 seconds
PyTorch ISTFT: 0.000716 seconds
PyTorch STFT: 0.000272 seconds
Rust ISTFT: 0.026649 seconds

Testing with: 4 channels, signal length 32768, n_fft 2048, hop_length 1024
Average Rust roundtrip error: 1.2302921848087881e-16
Average PyTorch roundtrip error: 9.679753242194043e-09
Average roundtrip error (Rust STFT -> PyTorch ISTFT): 1.7571041522967382e-08
Average roundtrip error (PyTorch STFT -> Rust ISTFT): 1.5074709989037257e-08

Average run times:
Rust STFT + ISTFT: 0.224398 seconds
PyTorch STFT + ISTFT: 0.021696 seconds
Rust STFT: 0.117436 seconds
PyTorch ISTFT: 0.002956 seconds
PyTorch STFT: 0.001173 seconds
Rust ISTFT: 0.106868 seconds

Testing with: 8 channels, signal length 65536, n_fft 4096, hop_length 2048
Average Rust roundtrip error: 1.379474655196007e-16
Average PyTorch roundtrip error: 9.650499519246785e-09
Average roundtrip error (Rust STFT -> PyTorch ISTFT): 1.7393561622267553e-08
Average roundtrip error (PyTorch STFT -> Rust ISTFT): 1.5055924811920244e-08

Average run times:
Rust STFT + ISTFT: 0.986159 seconds
PyTorch STFT + ISTFT: 0.062225 seconds
Rust STFT: 0.510961 seconds
PyTorch ISTFT: 0.009001 seconds
PyTorch STFT: 0.003260 seconds
Rust ISTFT: 0.475791 seconds

HEnquist commented 1 month ago

Nice! To speed it up you should create a single planner and reuse that for every FFT call. The way it works now, where it creates a new planner every time, means that it spends more time creating the planner and the FFT instance than performing the actual transform.
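To illustrate the "plan once, reuse for every frame" pattern without depending on any crate, here is a std-only sketch. `DftPlan` is a hypothetical stand-in, not RustFFT's API: it precomputes a twiddle table and runs a naive O(n²) DFT. In RustFFT the corresponding objects would be a single `FftPlanner` and the FFT instance returned by `plan_fft_forward`, both created outside the per-frame loop.

```rust
use std::f64::consts::PI;

/// Hypothetical stand-in for an FFT plan: setup work done once per transform length.
struct DftPlan {
    len: usize,
    /// (cos, sin) of -2*pi*k/len for k in 0..len.
    twiddles: Vec<(f64, f64)>,
}

impl DftPlan {
    /// The "expensive" step: build the twiddle table for one length.
    fn new(len: usize) -> Self {
        let twiddles = (0..len)
            .map(|k| {
                let angle = -2.0 * PI * k as f64 / len as f64;
                (angle.cos(), angle.sin())
            })
            .collect();
        DftPlan { len, twiddles }
    }

    /// Naive O(n^2) DFT of one real-valued frame; returns (re, im) pairs.
    fn process(&self, frame: &[f64]) -> Vec<(f64, f64)> {
        assert_eq!(frame.len(), self.len);
        (0..self.len)
            .map(|k| {
                let mut re = 0.0;
                let mut im = 0.0;
                for (n, &x) in frame.iter().enumerate() {
                    let (c, s) = self.twiddles[(k * n) % self.len];
                    re += x * c;
                    im += x * s;
                }
                (re, im)
            })
            .collect()
    }
}

/// Transform every hop-spaced frame, building the plan once up front.
fn stft_frames(signal: &[f64], n_fft: usize, hop: usize) -> Vec<Vec<(f64, f64)>> {
    let plan = DftPlan::new(n_fft); // created once, outside the loop
    signal
        .windows(n_fft)
        .step_by(hop)
        .map(|frame| plan.process(frame)) // reused for every frame
        .collect()
}

fn main() {
    let signal: Vec<f64> = (0..64).map(|i| (i as f64 * 0.1).sin()).collect();
    let frames = stft_frames(&signal, 16, 8);
    println!("{} frames of {} bins", frames.len(), frames[0].len());
}
```

The point is only where `DftPlan::new` is called: moving it inside the closure would repeat the setup for every frame, which is the slowdown described above.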

phudtran commented 1 month ago

> Nice! To speed it up you should create a single planner and reuse that for every FFT call. The way it works now, where it creates a new planner every time, means that it spends more time creating the planner and the FFT instance than performing the actual transform.

Thank you for the suggestion! I'm reusing both the planner and windows, but both forward and inverse are still very slow compared to PyTorch (especially inverse). Any idea how I can improve it further?

HEnquist commented 1 month ago

I don't see how the planner gets reused between FFT calls; here it looks like it creates a new one on each call: https://github.com/phudtran/rustft/blob/7bdcf44d191302f022c1a7553a7180d90e12af54/benchmarks/src/lib.rs#L49 You should have a single planner that gets used by all the calls, but I'm not sure how you would go about doing that when calling from Python.

I would start a little simpler by implementing some benchmarks in Rust first. At the moment it's not clear which step (or steps) in a processing call is responsible for the slowness.
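A minimal sketch of what such a Rust-side benchmark could look like, using only `std::time::Instant`. The `setup` and `transform` functions are hypothetical placeholders for "create the planner/FFT instance" and "run the transform"; the idea is just to time each stage separately so the slow one stands out.

```rust
use std::hint::black_box;
use std::time::{Duration, Instant};

/// Run `f` `iters` times and return the mean wall-clock time per call.
fn time_avg<F: FnMut()>(iters: u32, mut f: F) -> Duration {
    assert!(iters > 0);
    let start = Instant::now();
    for _ in 0..iters {
        f();
    }
    start.elapsed() / iters
}

/// Hypothetical placeholder for one-time setup (e.g. planner and window creation).
fn setup(n: usize) -> Vec<f64> {
    (0..n).map(|i| (i as f64).sin()).collect()
}

/// Hypothetical placeholder for the per-frame work (e.g. windowing plus transform).
fn transform(table: &[f64], frame: &mut [f64]) {
    for (x, t) in frame.iter_mut().zip(table) {
        *x *= *t;
    }
}

fn main() {
    let n = 1024;
    let mut frame = vec![1.0_f64; n];

    // Stage 1: setup cost alone.
    let setup_time = time_avg(100, || {
        black_box(setup(n));
    });

    // Stage 2: per-call cost with the setup already done.
    let table = setup(n);
    let transform_time = time_avg(100, || {
        transform(&table, &mut frame);
        black_box(&frame);
    });

    println!("setup: {:?} per call, transform: {:?} per call", setup_time, transform_time);
}
```

For more reliable numbers a dedicated benchmarking harness is usually preferable, but a split like this is often enough to tell whether setup or the transform dominates.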

phudtran commented 1 month ago

> I don't see how the planner gets reused between FFT calls; here it looks like it creates a new one on each call: https://github.com/phudtran/rustft/blob/7bdcf44d191302f022c1a7553a7180d90e12af54/benchmarks/src/lib.rs#L49
>
> You should have a single planner that gets used by all the calls, but I'm not sure how you would go about doing that when calling from Python.
>
> I would start a little simpler by implementing some benchmarks in Rust first. At the moment it's not clear which step (or steps) in a processing call is responsible for the slowness.

Yeah, unfortunately I can't pass in a planner from Python, so the Python benchmarks have not been updated. I'll time the function in Rust instead, after the planner is created, and return that along with the result.

WalterSmuts commented 1 month ago

> Yeah unfortunately I can't pass in a planner from Python

Not sure if this is completely true, but consider using easyfft, which manages the planner for you.