LouisDesdoigts / dLux

Differentiable optical models as parameterised neural networks in Jax using Zodiax
https://louisdesdoigts.github.io/dLux/
BSD 3-Clause "New" or "Revised" License

Custom Grads to Avoid Soft Edges #181

Closed. Jordan-Dennis closed this issue 1 year ago

Jordan-Dennis commented 1 year ago

Hi all,

In amongst everything else, I have been wanting to test using custom_jvp to define our own gradients for the apertures with respect to their parameters. I've been too afraid to try this with equinox objects, but based on #180 I could do it with functions. The advantages are that hard-edged apertures are easier (and faster) to create and that the accuracy is guaranteed.

There are many technical issues to face, though. For example, it is easy to define the custom derivative of the power passing through the aperture, but whether this can then be used for anything else (without defining a whole bunch of other things) becomes the problem. Either way, I want to try and investigate custom_jvp.

Regards,
Jordan
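A minimal sketch of the "easy" case mentioned above, the power through a hard-edged circular aperture with a hand-written derivative with respect to the radius. It assumes uniform unit illumination, so that dP/dr = 2πr; the names `aperture_power`, `RHO` and `PIXEL_SCALE` are hypothetical and are not part of dLux.

```python
import jax
import jax.numpy as jnp

# Hypothetical grid: 256 x 256 pixels spanning [-1, 1] in both axes.
NPIX, WIDTH = 256, 2.0
PIXEL_SCALE = WIDTH / NPIX
axis = (jnp.arange(NPIX) - (NPIX - 1) / 2) * PIXEL_SCALE
x, y = jnp.meshgrid(axis, axis)
RHO = jnp.hypot(x, y)  # radial coordinate of each pixel centre


@jax.custom_jvp
def aperture_power(radius):
    """Power through a hard-edged circular aperture (uniform illumination assumed)."""
    mask = jnp.where(RHO < radius, 1.0, 0.0)
    return jnp.sum(mask) * PIXEL_SCALE**2


@aperture_power.defjvp
def aperture_power_jvp(primals, tangents):
    (radius,) = primals
    (radius_dot,) = tangents
    primal_out = aperture_power(radius)
    # Assumption: uniform unit illumination, so P = pi r^2 and dP/dr = 2 pi r.
    tangent_out = 2.0 * jnp.pi * radius * radius_dot
    return primal_out, tangent_out


# The custom rule is used by jax.grad instead of differentiating the mask.
grad_radius = jax.grad(aperture_power)(0.5)
```

This only supplies a derivative for the scalar power. Anything downstream that consumes the aperture array itself would still see a mask with no useful derivative, which is the difficulty raised in the comment.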

benjaminpope commented 1 year ago

I don't think we want to do this. A custom gradient of the power is easy, but we actually need a partial derivative with respect to the pixels in order to do any backpropagation. Worse than this, a custom smooth gradient of a function with a discrete edge is not the gradient of that function and, without due care and attention, will not behave nicely in optimization.
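To illustrate the point about per-pixel derivatives, the sketch below compares the Jacobian of a hard-edged mask with that of a soft-edged one, both taken with respect to the radius. The sigmoid edge and its `width` parameter are assumptions chosen for illustration, not necessarily how dLux builds soft edges.

```python
import jax
import jax.numpy as jnp

# Hypothetical 64 x 64 grid spanning [-1, 1] in both axes.
npix = 64
axis = (jnp.arange(npix) - (npix - 1) / 2) * (2.0 / npix)
x, y = jnp.meshgrid(axis, axis)
rho = jnp.hypot(x, y)


def hard_aperture(radius):
    # The comparison produces booleans, so no derivative flows through it.
    return jnp.where(rho < radius, 1.0, 0.0)


def soft_aperture(radius, width=0.05):
    # Assumed soft edge: a sigmoid of the signed distance to the edge.
    return jax.nn.sigmoid((radius - rho) / width)


# d(pixel value)/d(radius) for every pixel, each of shape (npix, npix).
hard_jac = jax.jacfwd(hard_aperture)(0.5)
soft_jac = jax.jacfwd(soft_aperture)(0.5)
```

The hard-edged Jacobian is zero in every pixel, so nothing can be backpropagated through it, while the soft-edged one is non-zero in a band around the edge.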

Jordan-Dennis commented 1 year ago

Yeah, as soon as I started trying I realised it was gonna be much harder than I thought. I came up with something like this for the gradient of a circular aperture with respect to the radius, but I'm not sure how to validate it.

[image]
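One possible consistency check, sketched under assumptions: for uniform unit illumination, a per-pixel gradient with respect to the radius should integrate over the grid to dP/dr = 2πr. The formula from the attached image is not reproduced here, so a sigmoid soft edge (`candidate_aperture`, a hypothetical name) stands in for the candidate.

```python
import jax
import jax.numpy as jnp

# Hypothetical grid: 256 x 256 pixels spanning [-1, 1] in both axes.
NPIX, WIDTH = 256, 2.0
PIXEL_SCALE = WIDTH / NPIX
axis = (jnp.arange(NPIX) - (NPIX - 1) / 2) * PIXEL_SCALE
x, y = jnp.meshgrid(axis, axis)
RHO = jnp.hypot(x, y)


def candidate_aperture(radius, edge_width=0.02):
    # Stand-in candidate: sigmoid soft edge (an assumption, not the thread's formula).
    return jax.nn.sigmoid((radius - RHO) / edge_width)


def integrated_gradient(aperture_fn, radius):
    """Sum the per-pixel d(aperture)/d(radius) over the grid, weighted by pixel area."""
    per_pixel = jax.jacfwd(aperture_fn)(radius)
    return jnp.sum(per_pixel) * PIXEL_SCALE**2


radius = 0.5
numeric = integrated_gradient(candidate_aperture, radius)
analytic = 2.0 * jnp.pi * radius  # d(pi r^2)/dr for uniform unit illumination
```

Agreement between `numeric` and `analytic` is a necessary condition only: it checks the total, not how the gradient is distributed between pixels.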

benjaminpope commented 1 year ago

The step function is the integral of a Dirac delta measure, so its gradient is not defined at the edge.
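Written out, with $H$ the Heaviside step, $\rho$ the radial coordinate, $r$ the aperture radius and $I$ the illumination (symbols introduced here for illustration), the statement and its consequence for the power read roughly as follows:

```latex
\begin{align}
H(x) &= \int_{-\infty}^{x} \delta(t)\,\mathrm{d}t
  \quad\Longrightarrow\quad
  \frac{\mathrm{d}H}{\mathrm{d}x} = \delta(x)
  \ \text{(in the distributional sense only)}, \\
A(\rho; r) &= H(r - \rho)
  \quad\Longrightarrow\quad
  \frac{\partial A}{\partial r} = \delta(r - \rho), \\
P(r) &= \int A(\rho; r)\, I \,\mathrm{d}^2x
  \quad\Longrightarrow\quad
  \frac{\mathrm{d}P}{\mathrm{d}r} = \oint_{\rho = r} I \,\mathrm{d}s
  = 2\pi r\, I_0 \ \text{ for uniform } I = I_0 .
\end{align}
```

The integrated quantity $\mathrm{d}P/\mathrm{d}r$ is finite and well defined, whereas the per-pixel derivative $\partial A/\partial r$ is zero away from the edge and not a function on it.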