An issue with denormals in the phase loss calculation causes the `torch.angle` function to return NaN values, as discussed here. One possible workaround is to replace values whose magnitude falls below a small threshold before computing the angle.
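```python
import torch
from torch import Tensor


def replace_denormals(x: Tensor, threshold: float = 1e-10) -> Tensor:
    """Replace numbers close to zero to avoid NaNs in `angle`."""
    y = x.clone()
    # Every element whose magnitude is below `threshold` (exact zeros and
    # tiny negative values included) is set to +threshold.
    y[torch.abs(x) < threshold] = threshold
    return y


def angle(x: Tensor) -> Tensor:
    """Calculate the angle (phase) of a complex or real tensor."""
    if torch.is_complex(x):
        x_real = x.real
        x_imag = x.imag
    else:
        # A real tensor is treated as having a zero imaginary part.
        x_real = x
        x_imag = torch.zeros_like(x_real)
    # Keep the arguments of atan2 away from (0, 0), where its derivative is 0/0.
    x_real = replace_denormals(x_real)
    x_imag = replace_denormals(x_imag)
    return torch.atan2(x_imag, x_real)
```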
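To check that the thresholding actually helps, here is a minimal sketch comparing gradients of `torch.atan2` with and without it. It assumes PyTorch with autograd; the tensor values are made-up test data containing one exact-zero bin. The partial derivatives of atan2(y, x) are -y/(x² + y²) and x/(x² + y²), so at (0, 0) the naive version yields a NaN gradient, while the thresholded version stays finite. The helper is repeated here so the snippet runs on its own.

```python
import torch
from torch import Tensor


def replace_denormals(x: Tensor, threshold: float = 1e-10) -> Tensor:
    # Same thresholding as the helper in the post.
    y = x.clone()
    y[torch.abs(x) < threshold] = threshold
    return y


# Hypothetical spectrum: the middle bin is exactly zero.
real = torch.tensor([1.0, 0.0, -2.0], requires_grad=True)
imag = torch.tensor([1.0, 0.0, 0.5], requires_grad=True)

# Naive phase: the gradient at the zero bin is 0 * inf = NaN.
torch.atan2(imag, real).sum().backward()
naive_grad = real.grad.clone()
print(naive_grad)  # middle entry is nan

# Reset accumulated gradients before the second pass.
real.grad = None
imag.grad = None

# Thresholded phase: atan2 is evaluated at (1e-10, 1e-10) instead of (0, 0),
# and the index assignment passes a zero gradient back to the replaced entries.
torch.atan2(replace_denormals(imag), replace_denormals(real)).sum().backward()
safe_grad = real.grad.clone()
print(safe_grad)  # all entries finite
```

Note that replacing with +threshold also turns tiny negative values positive, which changes their phase; for a loss this is presumably harmless, since such bins carry negligible energy.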