MPoL-dev / MPoL

A flexible Python platform for Regularized Maximum Likelihood imaging
https://mpol-dev.github.io/MPoL/
MIT License

`fourier.py` type hinted #172

Closed: kadri-nizam closed this PR 1 year ago

kadri-nizam commented 1 year ago

The module is now fully type-hinted. I also simplified some of the logic, mostly by removing branching where it wasn't needed. The logic changes are in the `NUFFT` `__init__`, `forward`, and `_assemble_ktraj` methods.

kadri-nizam commented 1 year ago

> ... are all handled and broadcast correctly using the batch/coil dimensions as needed?

I went through the logic again and believe it is the same as the previous implementation. I'm happy to pair-code with you over one of our hackathon sessions to ensure I didn't miss anything though. We can keep the PR open until then.


> I'd like to preserve the ability to use sparse matrices with multi-channel image cubes with `same_uv=False` in the future. Does this flow do that?

I based the refactored logic on the current implementation, and I don't think it accounts for that case right now. If `sparse_matrices=True` and `same_uv=False`, the `else` clause on L253 is executed, which sets `sparse_matrices=False` on L261: https://github.com/MPoL-dev/MPoL/blob/0cd0042c506cd683302d0dae3a1a0a667798ee0a/src/mpol/fourier.py#L244-L261 How would you like to proceed?
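To make the branch concrete, here is a hypothetical, heavily simplified reduction of the flag handling described above. `resolve_sparse_matrices` is an illustrative name, not an MPoL function, and the warning is an assumption added for illustration; the linked source is the authoritative version.

```python
import warnings


def resolve_sparse_matrices(sparse_matrices: bool, same_uv: bool) -> bool:
    """Hypothetical sketch: return the sparse_matrices value actually used.

    Not MPoL's code; it only mirrors the override described in the thread.
    """
    if sparse_matrices and not same_uv:
        # Per-channel (u, v) points with sparse matrices requested: this
        # combination is not handled, so the flag is overridden to False.
        # (Emitting a warning here is an illustrative assumption.)
        warnings.warn(
            "sparse_matrices=True is not handled with same_uv=False; "
            "falling back to sparse_matrices=False.",
            RuntimeWarning,
        )
        return False
    # Every other combination keeps the requested value.
    return sparse_matrices
```

Calling `resolve_sparse_matrices(True, False)` returns `False`, so preserving sparse matrices for multi-channel cubes with `same_uv=False` would mean replacing this override with a real implementation rather than refactoring around it.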


> Also, can you say a little more about what the protocol is doing?

Sure! Python's `typing.Protocol` enables structural subtyping, i.e. static duck typing. In this case, the function:

```python
def get_vis_residuals(
    model: MPoLModel,
    ...
```

has a required parameter `model`, which is generally an instance of an `nn.Module` subclass. However, we also expect it to have additional attributes holding MPoL objects (`GridCoords`, `ImageCube`, etc.; the `precomposed.SimpleNet` model is an example). So with the `MPoLModel` protocol:

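```python
from typing import Protocol
import mpol.coordinates
import mpol.fourier
import mpol.images

class MPoLModel(Protocol):
    # MPoL objects that any compatible model is expected to expose
    coords: mpol.coordinates.GridCoords
    nchan: int
    bcube: mpol.images.BaseCube
    icube: mpol.images.ImageCube
    fcube: mpol.fourier.FourierCube

    def forward(self):  # also required by nn.Module
        ...
```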

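Static type checkers (e.g. mypy or Pyright) will then let the user know that any custom model they pass to the function must, at a minimum, define the attributes declared in the protocol as well as a `forward` method, satisfying the requirements of both MPoL and `nn.Module`. The check is static: nothing is enforced at runtime unless the protocol is also decorated with `@runtime_checkable`.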
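A minimal, self-contained sketch of the same idea, assuming stand-in classes in place of the real MPoL ones so it runs without MPoL or PyTorch installed. `GridCoordsStub`, `HasCoords`, `MyModel`, `Incomplete`, and `describe` are hypothetical names. `@runtime_checkable` is added only so the structural check can also be shown with `isinstance`, which tests that the members exist, not their types.

```python
from typing import Protocol, runtime_checkable


class GridCoordsStub:
    """Hypothetical stand-in for mpol.coordinates.GridCoords."""

    def __init__(self, cell_size: float, npix: int) -> None:
        self.cell_size = cell_size
        self.npix = npix


@runtime_checkable
class HasCoords(Protocol):
    # Any object with these members satisfies the protocol structurally.
    coords: GridCoordsStub
    nchan: int

    def forward(self) -> float:
        ...


class MyModel:  # note: does NOT inherit from HasCoords
    def __init__(self) -> None:
        self.coords = GridCoordsStub(cell_size=0.005, npix=800)
        self.nchan = 1

    def forward(self) -> float:
        return 0.0


class Incomplete:  # has forward() but lacks coords and nchan
    def forward(self) -> float:
        return 0.0


def describe(model: HasCoords) -> str:
    # A type checker accepts any argument that structurally matches HasCoords.
    return f"{model.nchan} channel(s), {model.coords.npix} pixels"
```

Here `describe(MyModel())` type-checks and returns `"1 channel(s), 800 pixels"` even though `MyModel` never subclasses `HasCoords`, while mypy would flag `describe(Incomplete())` before the code runs.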

iancze commented 1 year ago

> I went through the logic again and believe it is the same as the previous implementation. I'm happy to pair-code with you over one of our hackathon sessions to ensure I didn't miss anything though. We can keep the PR open until then.

Great, I think that will be the quickest route to make sure we're both on the same page.

The protocol stuff sounds nice, thanks.