wavefrontshaping / complexPyTorch

A high-level toolbox for using complex valued neural networks in PyTorch
MIT License

Add default dtype to constructors and other complex methods to allow flexibility for users #36

Closed: rmclimber closed this 4 months ago

rmclimber commented 5 months ago

I would like to use this package with higher precision (e.g. torch.complex128), so I have added the ability to pass a dtype.

I went through all the functions and classes in the package. For each class in complexLayers.py, I added an optional dtype=torch.complex64 argument to the constructor and stored it as an instance attribute, self.dtype = dtype. In the class methods, I replaced torch.complex64 with self.dtype. For each function in complexFunctions.py that uses a dtype at all, I added an optional dtype=torch.complex64 argument and replaced torch.complex64 with dtype in the function body.
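To illustrate the pattern described above, here is a minimal sketch. The names complex_phase_shift and ComplexPhaseShift are hypothetical and do not exist in complexPyTorch; they only show a function that takes an optional dtype defaulting to torch.complex64, and a layer that stores the constructor's dtype on the instance and passes it through in forward.

```python
import torch
import torch.nn as nn

# Map each complex dtype to the real dtype of matching precision.
_REAL_DTYPE = {torch.complex64: torch.float32, torch.complex128: torch.float64}


# Hypothetical function (not part of complexPyTorch): the dtype of the
# internal complex tensor is an optional argument defaulting to complex64.
def complex_phase_shift(inp, phase, dtype=torch.complex64):
    # exp(i * phase), built with unit magnitude, then cast to the requested dtype.
    factor = torch.polar(torch.ones_like(phase), phase).to(dtype)
    return inp.to(dtype) * factor


# Hypothetical layer (not part of complexPyTorch): the constructor takes an
# optional dtype and keeps it in self.dtype instead of hard-coding complex64.
class ComplexPhaseShift(nn.Module):
    def __init__(self, num_features, dtype=torch.complex64):
        super().__init__()
        self.dtype = dtype
        # Learnable per-feature phase, stored at the matching real precision.
        self.phase = nn.Parameter(
            torch.zeros(num_features, dtype=_REAL_DTYPE[dtype])
        )

    def forward(self, inp):
        # Use the instance's dtype rather than a hard-coded torch.complex64.
        return complex_phase_shift(inp, self.phase, dtype=self.dtype)
```

Existing code such as ComplexPhaseShift(3) keeps producing complex64 outputs, while ComplexPhaseShift(3, dtype=torch.complex128) opts in to double precision. Note that self.dtype is a plain attribute, so a later module.to(...) call would not update it.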

Without a test suite, I can't confidently verify that nothing broke, but I used default arguments so that current users do not need to change their function or constructor calls. I'm fairly sure this is fully backward-compatible with existing code.
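As a rough sketch of the kind of test that could back up the backward-compatibility claim, the check below compares a hypothetical "before" function (dtype hard-coded) with its "after" version (dtype as an optional argument). Both make_mask_old and make_mask_new are illustrative stand-ins, not functions from complexPyTorch.

```python
import torch


# Hypothetical "before" version: dtype hard-coded to complex64.
def make_mask_old(shape):
    return torch.ones(shape, dtype=torch.complex64)


# Hypothetical "after" version: dtype is optional, defaulting to the old value.
def make_mask_new(shape, dtype=torch.complex64):
    return torch.ones(shape, dtype=dtype)


def check_backward_compatible():
    shape = (2, 3)
    old = make_mask_old(shape)
    new = make_mask_new(shape)  # existing call site, unchanged
    # Same dtype and same values as before the change.
    assert new.dtype == old.dtype
    assert torch.equal(new, old)
    # New capability: callers can opt in to double precision.
    assert make_mask_new(shape, dtype=torch.complex128).dtype == torch.complex128


check_backward_compatible()
```

This only holds if dtype is appended after the existing parameters, so that positional arguments in existing calls still bind to the same parameters.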

rmclimber commented 4 months ago

I believe I caught a bug, so I'm closing this for now. I'll open a new PR once I have everything fixed.