SonyCSLParis / pesto

Self-supervised learning for fast pitch estimation
GNU Lesser General Public License v3.0

CQT kernel device is reset when sample rate is changed #17

Closed ben-hayes closed 8 months ago

ben-hayes commented 8 months ago

Expected behavior

After calling DataProcessor.to(device), submodule parameters and buffers (i.e. CQT kernels) should remain on that device.

Actual behavior

If DataProcessor.sampling_rate is changed, a new CQT submodule is constructed on self.device. That attribute is set at construction and is not updated by calls to .to(device), .cpu(), or .cuda(), so the new CQT kernels can end up on a different device from the rest of the module.

Minimal example

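import pesto

# Create the data processor on CPU; setting the sampling rate builds the CQT submodule.
dp = pesto.utils.load_dataprocessor(1e-2, device="cpu")
dp.sampling_rate = 44100
print(dp.cqt.cqt_kernels_real.device)
print(dp.cqt.cqt_kernels_imag.device)

# Move the whole module to the GPU; the existing CQT kernels follow.
dp.to("cuda")
print(dp.cqt.cqt_kernels_real.device)
print(dp.cqt.cqt_kernels_imag.device)

# Change the sampling rate again; the CQT is rebuilt on the stale self.device.
dp.sampling_rate = 48000
print(dp.cqt.cqt_kernels_real.device)
print(dp.cqt.cqt_kernels_imag.device)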

Outputs:

cpu   
cpu   
cuda:0
cuda:0
cpu   
cpu  

We would expect the last two lines to be cuda:0, since the module was moved to cuda before the sampling rate was changed.
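
To make the mechanism concrete, here is a minimal sketch that reproduces the same pattern without pesto, together with one possible remedy. Everything in it is hypothetical: ToyKernel, StaleDeviceProcessor and LiveDeviceProcessor are illustrative stand-ins, not pesto classes, and the remedy is not necessarily the fix that was adopted in the repository. It uses PyTorch's "meta" device as the move target so that it runs on machines without a GPU; with CUDA available, "cuda" could be passed instead.

```python
import torch
from torch import nn


class ToyKernel(nn.Module):
    """Hypothetical stand-in for a CQT submodule: it holds a single buffer."""

    def __init__(self, sampling_rate: int, device: torch.device):
        super().__init__()
        kernel = torch.full((4,), float(sampling_rate))
        self.register_buffer("kernel", kernel.to(device))


class StaleDeviceProcessor(nn.Module):
    """Caches the device at construction, as described in the report."""

    def __init__(self, sampling_rate: int, device="cpu"):
        super().__init__()
        self.device = torch.device(device)  # plain attribute, ignored by .to()
        self._sampling_rate = sampling_rate
        self.cqt = ToyKernel(sampling_rate, self.device)

    @property
    def sampling_rate(self) -> int:
        return self._sampling_rate

    @sampling_rate.setter
    def sampling_rate(self, value: int) -> None:
        self._sampling_rate = value
        # Rebuilt on the cached device, wherever the module lives now.
        self.cqt = ToyKernel(value, self.device)


class LiveDeviceProcessor(StaleDeviceProcessor):
    """Possible remedy: read the device from an existing buffer instead."""

    @property
    def sampling_rate(self) -> int:
        return self._sampling_rate

    @sampling_rate.setter
    def sampling_rate(self, value: int) -> None:
        current_device = self.cqt.kernel.device  # follows .to(), .cpu(), .cuda()
        self._sampling_rate = value
        self.cqt = ToyKernel(value, current_device)


def device_after_rate_change(cls, target="meta") -> str:
    """Build on CPU, move to `target`, change the rate, report the kernel's device type."""
    dp = cls(44100, device="cpu")
    dp.to(target)
    dp.sampling_rate = 48000
    return dp.cqt.kernel.device.type


if __name__ == "__main__":
    print(device_after_rate_change(StaleDeviceProcessor))
    print(device_after_rate_change(LiveDeviceProcessor))
```

The design point is that buffers registered with register_buffer are moved by .to(), whereas a plain attribute such as self.device is not, so reading the device from a live buffer avoids keeping a second copy of that state.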