wavefrontshaping / complexPyTorch

A high-level toolbox for using complex valued neural networks in PyTorch
MIT License
610 stars, 148 forks

Complex MSELoss #27

Status: Open. dstark1993 opened this issue 1 year ago.

dstark1993 commented 1 year ago

Feature request: a complex-valued counterpart to torch.nn.MSELoss().

I think the function is fairly straightforward, as discussed in https://github.com/pytorch/pytorch/issues/46642:

import torch

def complex_mse_loss(output, target):
    return 0.5 * (output - target).abs().pow(2).mean()  # real-valued: half the mean of |output - target|^2
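A minimal self-contained sketch for sanity-checking a real-valued complex MSE, assuming the 0.5 factor is meant to make the loss match `torch.nn.MSELoss` applied to the real and imaginary parts treated as separate real entries. `torch.view_as_real` turns a complex tensor of shape `(4, 3)` into a real tensor of shape `(4, 3, 2)`, so the default `'mean'` reduction averages over twice as many entries, which gives `0.5 * mean(|output - target|^2)`. The tensor shapes and random seed are arbitrary choices for illustration.

```python
import torch
import torch.nn as nn


def complex_mse_loss(output, target):
    # Real-valued loss: half the mean squared magnitude of the complex error.
    return 0.5 * (output - target).abs().pow(2).mean()


torch.manual_seed(0)
output = torch.randn(4, 3, dtype=torch.complex64, requires_grad=True)
target = torch.randn(4, 3, dtype=torch.complex64)

loss = complex_mse_loss(output, target)

# Reference: standard MSELoss on the (real, imag) pairs viewed as plain real numbers.
reference = nn.MSELoss()(torch.view_as_real(output), torch.view_as_real(target))

print(loss.dtype)                       # torch.float32
print(torch.allclose(loss, reference))  # True

# A real scalar loss supports backward(); calling backward() on a complex scalar raises an error.
loss.backward()
print(output.grad.shape)                # torch.Size([4, 3])
```

Taking the squared magnitude rather than squaring the complex difference is what keeps the loss real: `(output - target)**2` is itself a complex number whose real part can be negative, so it does not behave like a distance, and autograd will not implicitly create a gradient for a complex scalar output.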