LouisDesdoigts / dLux

Differentiable optical models as parameterised neural networks in Jax using Zodiax
https://louisdesdoigts.github.io/dLux/
BSD 3-Clause "New" or "Revised" License

Allow Selection of Soft Edges #183

Closed: Jordan-Dennis closed this issue 9 months ago

Jordan-Dennis commented 1 year ago

Hi all, I have implemented tanh, linear, and sigmoid soft edges while trying to get a performance edge. They all have more or less the same performance, but it is easier to set the width precisely using linear. Should we allow the user to select the soft-edging mode, or just provide one?

Regards

Jordan

benjaminpope commented 1 year ago

If linear gives good performance, that's OK by me, provided we benchmark the options carefully. I am anxious about it, to be honest, because it drops to a hard zero, which isn't differentiable at that point.


Jordan-Dennis commented 1 year ago

So, interestingly, the sigmoid returned a nan gradient. I expect this is because of numerical instability, but yes, linear works. They are all roughly equivalent. I'm tempted to see whether I can get the linear version to remain differentiable in the limit of a hard edge. It might also be possible to do this with tanh.

LouisDesdoigts commented 1 year ago

Personally I am a fan of linear because you can set the hard threshold in direct pixel units. I have never had any trouble with this method, and it has been optimised through before. It's just applying a different function to the distance metric and using np.clip, which is fully differentiable. Pixels far from the edge have zero gradient with respect to the aperture parameters anyway; this just swaps the function the gradient is calculated through to be linear. The thin boundary of pixels will always have defined gradients; it's basically a ReLU function with both a max and a min value.
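For concreteness, a minimal sketch of this clip-based edge, assuming a signed distance metric (positive inside the aperture) and a softening width; the names are illustrative, not the actual dLux implementation:

```python
import jax.numpy as np

def linear_soft_edge(distances, width):
    # distances: signed distance from the aperture edge (positive inside).
    # np.clip is a ReLU with both a floor and a ceiling: zero outside,
    # one inside, and a linear ramp of the given width in between.
    # Gradients are defined everywhere and non-zero only on the ramp.
    return np.clip(distances / width, 0.0, 1.0)
```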

I also got nan gradients when using sigmoid, I believe because the very large and small values it computes internally cause overflow/underflow, which kills the gradients.
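As an illustration of that failure mode (a toy reproduction, not the dLux code), a sigmoid written with a raw exp overflows in float32 and poisons the gradient, while the stable jax.nn.sigmoid does not:

```python
import jax
import jax.numpy as np

def naive_sigmoid(x):
    # exp(x) overflows to inf for large x (around x > 88 in float32),
    # so the forward pass hits inf / (1 + inf) = nan and so does the gradient
    return np.exp(x) / (1.0 + np.exp(x))

print(jax.grad(naive_sigmoid)(100.0))   # nan
print(jax.grad(jax.nn.sigmoid)(100.0))  # 0.0, the stable formulation
```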

LouisDesdoigts commented 1 year ago

But I see no reason not to keep both as options 🤷‍♂️

Jordan-Dennis commented 1 year ago

For the sigmoid, the JAX team give an example where they define a custom derivative for a similar kind of function for exactly that reason. We could do this, but it is similar enough to tanh that I don't think we necessarily need to worry.
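The example referred to is (close to) the log1pexp one from the JAX custom-derivatives tutorial, sketched here for reference:

```python
import jax
import jax.numpy as np

@jax.custom_jvp
def log1pexp(x):
    # naive definition: autodiff through this gives nan gradients at large x
    return np.log(1.0 + np.exp(x))

@log1pexp.defjvp
def log1pexp_jvp(primals, tangents):
    x, = primals
    x_dot, = tangents
    ans = log1pexp(x)
    # hand-written derivative, stable for large x: sigmoid(x) * x_dot
    return ans, (1.0 - 1.0 / (1.0 + np.exp(x))) * x_dot

print(jax.grad(log1pexp)(100.0))  # 1.0 instead of nan
```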

We can always set this up using the strategy design pattern.
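A hypothetical sketch of that pattern, with the softening function selected by name behind one shared interface (names and signatures are made up for illustration):

```python
import jax
import jax.numpy as np

# each strategy maps a signed edge distance and a width to [0, 1];
# the linear ramp is shifted by 0.5 so all three give 0.5 at the edge
SOFTEN_STRATEGIES = {
    "linear": lambda d, w: np.clip(d / w + 0.5, 0.0, 1.0),
    "tanh": lambda d, w: 0.5 * (np.tanh(d / w) + 1.0),
    "sigmoid": lambda d, w: jax.nn.sigmoid(d / w),
}

def soften(distances, width, mode="linear"):
    # look up the chosen strategy; all strategies share one call signature
    return SOFTEN_STRATEGIES[mode](distances, width)
```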

Jordan-Dennis commented 1 year ago

This is insignificant, and the np.tanh version was merged into main, so we will ignore it from here.

LouisDesdoigts commented 1 year ago

Re-opening this, as I'd still like to implement a linear slope at some point in the future.

Jordan-Dennis commented 1 year ago

I believe the following code is one way of generating a linear slope, and it was quite fast:

```python
import functools as ft
import jax
import jax.numpy as np

@ft.partial(jax.jit, inline=True)
def soften_v1(distances: float, nsoft: float, pixel_scale: float) -> float:
    # clamp negative (outside the edge) distances to zero
    lower: float = jax.lax.full_like(distances, 0., dtype=float)
    upper: float = jax.lax.full_like(distances, 1., dtype=float)
    inside: float = jax.lax.max(distances, lower)
    # linear ramp over a band nsoft pixels wide
    scaled: float = inside / nsoft / pixel_scale
    # cap at one inside the aperture (np.nanmin was a bug here; it
    # takes an axis, not a second array, so use an elementwise min)
    aperture: float = jax.lax.min(scaled, upper)
    return aperture
```

LouisDesdoigts commented 9 months ago

No longer required after changes in #246