LouisDesdoigts / dLux

Differentiable optical models as parameterised neural networks in Jax using Zodiax
https://louisdesdoigts.github.io/dLux/
BSD 3-Clause "New" or "Revised" License

Increase Primitive Usage. #178

Closed · Jordan-Dennis closed this issue 1 year ago

Jordan-Dennis commented 1 year ago

Hi all,

Based on the contents of #169 we can improve performance by lowering our code nearer to `lax`. How effective this is will, I think, vary: in most cases `jax` does all the lowering itself, so it will not be worth the effort, but in other cases I think it will give us better performance. For example, consider my MWE using `np.linspace` below. This is a micro-optimisation, but given how heavily we use the `coordinates` function, something like this might be handy.

[screenshot: MWE using `np.linspace`]
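Something in the spirit of that MWE, as a minimal sketch (the `lax_linspace` name and exact formulation are my own, not taken from the screenshot; it only matches `jnp.linspace` for `num > 1` with floating-point endpoints):

```python
import jax.numpy as jnp
from jax import lax

def lax_linspace(start, stop, num):
    # lax.iota is a single primitive; the rest is one scale and one shift,
    # skipping the extra bookkeeping inside jnp.linspace.
    step = (stop - start) / (num - 1)
    return start + lax.iota(jnp.float32, num) * step

# Sanity check against the jnp implementation.
assert jnp.allclose(jnp.linspace(0.0, 1.0, 5), lax_linspace(0.0, 1.0, 5))
```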

Regards Jordan

Jordan-Dennis commented 1 year ago

Here is another example:

[screenshot: second benchmark]

This one actually blows my mind.

Jordan-Dennis commented 1 year ago

So Jake VanderPlas advised (though not strictly) against this and said to "trust the compiler". When I suggested it I was under the impression that the jaxpr roughly translated into runtime performance, but Jake said this is not the case. It is still nice to have a fine level of control over the jaxpr (still an important step), since primitives map almost one to one into it. For the most part, though, I don't think we need to go that far.
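To make the "almost one to one" point concrete, here is a small sketch (mine, not from the thread) comparing the traced jaxprs; `jnp.linspace` expands into several primitives, while a `lax`-level version like the `lax_linspace` above stays close to one primitive per call. Per Jake's point, though, the shorter jaxpr does not necessarily compile to faster code:

```python
import jax
import jax.numpy as jnp
from jax import lax

def lax_linspace(start, stop, num=5):
    # Same hypothetical helper sketched earlier in the thread.
    step = (stop - start) / (num - 1)
    return start + lax.iota(jnp.float32, num) * step

# jnp.linspace traces into a handful of primitives...
print(jax.make_jaxpr(lambda start, stop: jnp.linspace(start, stop, 5))(0.0, 1.0))

# ...while the lax-level version maps nearly one to one into the jaxpr.
print(jax.make_jaxpr(lax_linspace)(0.0, 1.0))
```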