SyneRBI / SIRF

Main repository for the CCP SyneRBI software
http://www.ccpsynerbi.ac.uk

Feature Request: PyTorch Autograd support #901

Open ashgillman opened 3 years ago

ashgillman commented 3 years ago

This was discussed in Tuesday/Wednesday's Training School meeting.

This is done by subclassing torch.autograd.Function, which should define forward and backward static methods.

https://pytorch.org/tutorials/beginner/examples_autograd/two_layer_net_custom_function.html
https://pytorch.org/docs/stable/notes/extending.html

Here is an example of the implementation in PyroNN Torch: https://github.com/theHamsta/pyronn-torch/blob/master/src/pyronn_torch/parallel.py
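To illustrate the pattern those tutorials describe, here is a minimal toy autograd.Function (nothing SIRF-specific; the `Square` class is purely for illustration):

```python
import torch

class Square(torch.autograd.Function):
    @staticmethod
    def forward(ctx, x):
        # Save what backward will need.
        ctx.save_for_backward(x)
        return x * x

    @staticmethod
    def backward(ctx, grad_output):
        # d(x^2)/dx = 2x, chained with the incoming gradient.
        (x,) = ctx.saved_tensors
        return 2 * x * grad_output

x = torch.tensor([3.0], requires_grad=True)
y = Square.apply(x)
y.backward()
print(x.grad)  # tensor([6.])
```

Wrapping a SIRF operator follows the same shape, with the operator's adjoint playing the role of the derivative in backward.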

I wonder if we could get away with something as simple as:

```python
import torch

class _SIRFFunction(torch.autograd.Function):
    @staticmethod
    def forward(ctx, x, sirf_obj):
        ctx.sirf_obj = sirf_obj
        return sirf_obj.forward(x)

    @staticmethod
    def backward(ctx, y):
        # one gradient per forward() input; sirf_obj itself gets None
        return ctx.sirf_obj.adjoint(y), None

class PyTorchWrapper(torch.nn.Module):  # not sure if we actually want to extend Module...
    def __init__(self, sirf_obj):
        super().__init__()  # required when subclassing torch.nn.Module
        self.sirf_obj = sirf_obj

    def forward(self, x):
        return _SIRFFunction.apply(x, self.sirf_obj)
```

I'll give it a test to see if it works

KrisThielemans commented 3 years ago

ODL has an implementation for PyTorch and TensorFlow, used for instance here. Unfortunately, we cannot straight-copy it due to the Mozilla license, but it can provide inspiration.

ashgillman commented 3 years ago

Looking at ODL's implementation, the example above omits conversion to/from PyTorch (duh) and a quick escape if no gradient is required, but otherwise it's conceptually very similar. There's a lot of extra machinery for linear ops that I don't know we'll need, though that may be naive of me. I'll give it a spin next week anyway.
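A sketch of what those two missing pieces could look like, using a hypothetical numpy-backed matrix operator in place of a SIRF one (the `MatrixOp` class and its forward/adjoint names are assumptions for illustration, not SIRF API):

```python
import numpy as np
import torch

class MatrixOp:
    """Stand-in for a SIRF operator: numpy in, numpy out."""
    def __init__(self, A):
        self.A = A

    def forward(self, x):
        return self.A @ x

    def adjoint(self, y):
        return self.A.T @ y

class _SIRFFunction(torch.autograd.Function):
    @staticmethod
    def forward(ctx, x, op):
        ctx.op = op
        # torch -> numpy, apply the operator, numpy -> torch
        y = op.forward(x.detach().cpu().numpy())
        return torch.from_numpy(np.asarray(y))

    @staticmethod
    def backward(ctx, grad_output):
        # quick escape, ODL-style: skip the adjoint if x needs no grad
        if not ctx.needs_input_grad[0]:
            return None, None
        grad = ctx.op.adjoint(grad_output.detach().cpu().numpy())
        return torch.from_numpy(np.asarray(grad)), None

op = MatrixOp(np.array([[1.0, 2.0], [3.0, 4.0]]))
x = torch.tensor([1.0, 1.0], requires_grad=True, dtype=torch.float64)
y = _SIRFFunction.apply(x, op)
y.sum().backward()
print(x.grad)  # A^T @ [1, 1] = tensor([4., 6.], dtype=torch.float64)
```

For a real SIRF operator the conversions would go via `as_array()`/`fill()` rather than plain numpy arrays.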

KrisThielemans commented 3 years ago

Anyone tried this?

ashgillman commented 3 years ago

I started playing last week but haven't finished anything quite yet - do you want a prototype and when?


KrisThielemans commented 3 years ago

Thanks @ashgillman. If we have anything that could work, it'd be fun to merge it into 3.1, just in case... But there are quite a few other things on your plate, of course.

ashgillman commented 3 years ago

Haha, the above wasn't a promise :p but it's good to know the timelines. I'll have another look before the dev meeting on Thursday.

ckolbPTB commented 2 years ago

@ashgillman did you do any more work on this recently? We will give it a try for the MR side in a few weeks' time, so any info about problems/solutions/dead-ends you have encountered would be much appreciated.

ashgillman commented 2 years ago

@ckolbPTB Apologies for missing this. No, I didn't end up getting to it yet.

KrisThielemans commented 2 years ago

This is in progress at UCL. PR coming in 1-2 weeks

KrisThielemans commented 2 years ago

One of the things we need to do, of course, is to convert SIRF objects to torch objects. Current WIP code has statements like torch.from_numpy(sirf_object.as_array()). This copies via numpy and might be inefficient, so it would need to be addressed later.

I wonder if it would make sense to add methods DataContainer::as_torch() and DataContainer::fill(torch_object). Initially we can do this only on the Python side, going via numpy. Later on, we can do it better. This would make the code cleaner and more future-proof.
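A rough sketch of what the initial numpy-based version could look like. The names `as_torch`/`fill_from_torch` and the `MockContainer` class are assumptions for illustration; the real methods would live on SIRF's DataContainer, which exposes `as_array()` and `fill()`:

```python
import numpy as np
import torch

class MockContainer:
    """Mimics the bits of SIRF's DataContainer used here."""
    def __init__(self, arr):
        self._arr = np.asarray(arr, dtype=np.float32)

    def as_array(self):
        return self._arr.copy()

    def fill(self, value):
        self._arr[...] = value

def as_torch(container):
    # DataContainer -> torch tensor (one copy, via numpy)
    return torch.from_numpy(container.as_array())

def fill_from_torch(container, tensor):
    # torch tensor -> DataContainer (again via numpy)
    container.fill(tensor.detach().cpu().numpy())

c = MockContainer([1.0, 2.0, 3.0])
t = as_torch(c)
fill_from_torch(c, t * 2)
print(c.as_array())  # [2. 4. 6.]
```

Keeping the conversion behind these two entry points would let a later zero-copy or GPU-aware implementation drop in without touching calling code.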