csteinmetz1 / dasp-pytorch

Differentiable audio signal processors in PyTorch
https://csteinmetz1.github.io/dasp-pytorch/
Apache License 2.0

Distortion module sample rate #5

Open · oreillyp opened this issue 3 months ago

oreillyp commented 3 months ago

Because the Distortion module does not have a sample_rate attribute, calling .process_normalized() fails:

     42 denorm_param_dict = self.denormalize_param_dict(param_dict)
     44 # now process audio with denormalized parameters
     45 y = self.process_fn(
     46     x,
---> 47     self.sample_rate,
     48     **denorm_param_dict,
     49 )
     51 return y

AttributeError: 'Distortion' object has no attribute 'sample_rate'

One solution would be to add a sample_rate attribute to Distortion, even if it goes unused. Another would be to give the Processor base class a default sample_rate of None, so that modules that do not need a sample rate work without setting one.
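Here is a minimal sketch of the second option. It assumes a simplified Processor whose process_normalized() mirrors the call in the traceback above. The Processor, Distortion, and distortion definitions are hypothetical stand-ins, not the library's actual code: they use plain Python lists instead of tensors, a tanh waveshaper, and a linear mapping from [0, 1] to (min, max).

```python
import math
from typing import Callable, Dict, List, Optional, Tuple


def distortion(x: List[float], sample_rate: Optional[int], drive_db: float) -> List[float]:
    # Hypothetical stand-in: tanh waveshaper with input gain in dB; sample_rate is ignored.
    gain = 10 ** (drive_db / 20.0)
    return [math.tanh(gain * s) for s in x]


class Processor:
    # Hypothetical minimal base class. The class-level default means any
    # subclass that never assigns self.sample_rate still resolves it to None.
    sample_rate: Optional[int] = None
    process_fn: Callable
    param_ranges: Dict[str, Tuple[float, float]]

    def denormalize_param_dict(self, param_dict: Dict[str, float]) -> Dict[str, float]:
        # Assumed linear mapping of each normalized value in [0, 1] onto (min, max).
        return {
            name: min_val + param_dict[name] * (max_val - min_val)
            for name, (min_val, max_val) in self.param_ranges.items()
        }

    def process_normalized(self, x: List[float], param_dict: Dict[str, float]) -> List[float]:
        denorm_param_dict = self.denormalize_param_dict(param_dict)
        # Same call shape as in the traceback: sample_rate is passed positionally.
        return self.process_fn(x, self.sample_rate, **denorm_param_dict)


class Distortion(Processor):
    # Note: no sample_rate handling here; the base-class default covers it.
    def __init__(self, min_drive_db: float = 0.0, max_drive_db: float = 24.0):
        super().__init__()
        self.process_fn = distortion
        self.param_ranges = {"drive_db": (min_drive_db, max_drive_db)}
        self.num_params = len(self.param_ranges)


# Usage: a normalized drive of 0.5 maps to 12 dB with the default range.
dist = Distortion()
y = dist.process_normalized([0.0, 0.5, -0.5], {"drive_db": 0.5})
```

With this design, Distortion needs no changes. Normal attribute lookup falls back to the class-level None, so the AttributeError from the traceback cannot occur for any subclass.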

oreillyp commented 3 months ago

It also looks like the parameter name in the module's param_ranges does not match the keyword argument in the distortion function signature; it should be drive_db. The following fixes both errors:

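from typing import Optional

from dasp_pytorch.functional import distortion
from dasp_pytorch.modules import Processor


class Distortion(Processor):
    def __init__(
        self,
        sample_rate: Optional[int] = None,  # unused by distortion, but process_normalized() reads it
        min_drive_db: float = 0.0,
        max_drive_db: float = 24.0,
    ):
        super().__init__()
        self.sample_rate = sample_rate
        self.process_fn = distortion
        # The key must match the keyword argument of distortion(), i.e. drive_db
        self.param_ranges = {
            "drive_db": (min_drive_db, max_drive_db),
        }
        self.num_params = len(self.param_ranges)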
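The key in param_ranges matters because process_normalized() forwards the denormalized parameters to process_fn as keyword arguments (**denorm_param_dict). Below is a minimal stdlib sketch of the failure mode. It uses a hypothetical list-based stand-in for distortion and a hypothetical wrong key "drive"; neither is the library's code.

```python
import math
from typing import Dict, List, Optional


def distortion(x: List[float], sample_rate: Optional[int], drive_db: float) -> List[float]:
    # Hypothetical stand-in: tanh waveshaper with input gain in dB.
    gain = 10 ** (drive_db / 20.0)
    return [math.tanh(gain * s) for s in x]


def call_with_params(x: List[float], sample_rate: Optional[int], param_dict: Dict[str, float]) -> List[float]:
    # Mirrors process_normalized(): parameters are unpacked as keyword arguments.
    return distortion(x, sample_rate, **param_dict)


# Matching key: the call succeeds.
ok = call_with_params([0.1, -0.1], None, {"drive_db": 6.0})

# Mismatched (hypothetical) key: Python rejects it at call time.
try:
    call_with_params([0.1, -0.1], None, {"drive": 6.0})
    mismatch_error = None
except TypeError as err:
    mismatch_error = err
```

Python raises TypeError for a keyword the function does not accept rather than silently dropping the parameter. As a result, a mismatch between param_ranges and the function signature fails loudly on the first call.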