csteinmetz1 / auraloss

Collection of audio-focused loss functions in PyTorch
Apache License 2.0

Fix issue with denormals causing NaNs in phase loss #49

Open Kinyugo opened 1 year ago

Kinyugo commented 1 year ago

Denormal values in the phase loss calculation cause the torch.angle function to return NaN values, as discussed here. A possible fix is to replace every value whose magnitude is below a small threshold with the threshold itself before computing the angle.

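import torch
from torch import Tensor

def replace_denormals(x: Tensor, threshold: float = 1e-10) -> Tensor:
    """Replace numbers close to zero to avoid NaNs in `angle`"""
    # Work on a copy so the caller's tensor is left untouched.
    y = x.clone()
    # Note: small negative values are also mapped to +threshold.
    y[torch.abs(x) < threshold] = threshold
    return y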

def angle(x: Tensor) -> Tensor:
    """Calculates the angle of a complex or real tensor"""
    if torch.is_complex(x):
        x_real = x.real
        x_imag = x.imag
    else:
        x_real = x
        x_imag = torch.zeros_like(x_real)
    x_real = replace_denormals(x_real)
    x_imag = replace_denormals(x_imag)
    return torch.atan2(x_imag, x_real)
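
A quick way to sanity-check the proposal is to run it on an input that contains an exact zero and a subnormal-scale value, then confirm that both the phase and its gradient stay finite. The sketch below is only an illustration: it restates the two helpers compactly with torch.where (assumed to behave the same as the indexed assignment above), and the test tensor z is a made-up input, not data from auraloss.

```python
import torch
from torch import Tensor


def replace_denormals(x: Tensor, threshold: float = 1e-10) -> Tensor:
    # Compact restatement of the helper above (assumption: same behavior):
    # entries with |x| < threshold become +threshold, others pass through.
    return torch.where(torch.abs(x) < threshold, torch.full_like(x, threshold), x)


def angle(x: Tensor) -> Tensor:
    # Split into real and imaginary parts; real inputs get a zero imaginary part.
    if torch.is_complex(x):
        x_real, x_imag = x.real, x.imag
    else:
        x_real, x_imag = x, torch.zeros_like(x)
    return torch.atan2(replace_denormals(x_imag), replace_denormals(x_real))


# Hypothetical input: an exact zero, a subnormal-scale value, and a regular value.
z = torch.tensor(
    [0 + 0j, 1e-40 + 1e-40j, 1.0 + 1.0j],
    dtype=torch.complex64,
    requires_grad=True,
)

phase = angle(z)
phase.sum().backward()

print("phase:", phase.detach())
print("phase finite:", bool(torch.isfinite(phase).all()))
print("grad finite:", bool(torch.isfinite(z.grad).all()))
```

Because the clamped entries are replaced by a constant, no gradient flows back through them, so the zero and subnormal inputs contribute a zero gradient rather than a NaN. Whether that trade-off is acceptable for the phase loss is a design decision for the maintainers.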