LouisDesdoigts / dLux

Differentiable optical models as parameterised neural networks in Jax using Zodiax
https://louisdesdoigts.github.io/dLux/
BSD 3-Clause "New" or "Revised" License

Investigate Secondary Binding For `xla::cond`. #185

Closed: Jordan-Dennis closed this issue 1 year ago.

Jordan-Dennis commented 1 year ago

Hi all, this is a personal reminder to look into binding the `xla::cond` primitive to `dLux::cond`. It would behave exactly like `jax::cond`, except that it would not compile the branch functions. I want to do this so that `jax.profiler` can be used more effectively. Regards, Jordan.
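For context, here is a minimal sketch of the behaviour the comment refers to. It assumes a recent JAX release and takes `jax.lax.cond` to be what `jax::cond` means; `true_branch`, `false_branch`, and the `calls` counter are hypothetical. When the predicate is a traced value, `lax.cond` stages out *both* branch functions into a single `cond` primitive, and both sub-programs are compiled together.

```python
import jax
import jax.numpy as jnp

# Hypothetical counters recording how often each branch's Python body runs.
calls = {"true": 0, "false": 0}

def true_branch(x):
    calls["true"] += 1
    return jnp.sin(x)

def false_branch(x):
    calls["false"] += 1
    return jnp.cos(x)

def f(x):
    # The predicate depends on a traced value, so lax.cond must trace
    # both branches; only one of them runs when the program executes.
    return jax.lax.cond(x > 0, true_branch, false_branch, x)

closed_jaxpr = jax.make_jaxpr(f)(1.0)

# Both counters are now non-zero: each branch was traced.
print(calls)

# The staged-out program contains one `cond` equation holding both branches.
cond_eqns = [e for e in closed_jaxpr.jaxpr.eqns if e.primitive.name == "cond"]
print(len(cond_eqns), len(cond_eqns[0].params["branches"]))
```

Because both branches sit inside one `cond` operation, a profile presumably attributes time to that operation rather than to separately identifiable branch functions, which is the motivation stated above.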

Jordan-Dennis commented 1 year ago

I can just use `if` statements instead. I just need to be careful with how things get compiled.
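The `if`-statement route can be sketched as follows, assuming the branch is selected by a static Python flag. The function `g`, the flag `use_sin`, and the `traced` counter are hypothetical. Marking the flag static lets the `if` run at trace time, so only the chosen branch is compiled. The compilation caveats are that each new flag value triggers a fresh trace and compile, and that an `if` on a traced array value fails under `jax.jit`.

```python
from functools import partial

import jax
import jax.numpy as jnp

# Hypothetical counters recording which branch bodies were traced.
traced = {"sin": 0, "cos": 0}

@partial(jax.jit, static_argnames="use_sin")
def g(x, use_sin):
    # use_sin is a static Python bool, so this `if` is resolved during
    # tracing and only the selected branch is staged out and compiled.
    if use_sin:
        traced["sin"] += 1
        return jnp.sin(x)
    traced["cos"] += 1
    return jnp.cos(x)

g(1.0, use_sin=True)
print(traced)  # only the "sin" branch has been traced so far

@jax.jit
def bad(x):
    # x is a tracer here, so Python cannot decide which branch to take.
    if x > 0:
        return jnp.sin(x)
    return jnp.cos(x)

try:
    bad(1.0)
except jax.errors.ConcretizationTypeError:
    print("cannot branch on a traced value inside jit")
```

Calling `g` again with the same flag reuses the cached compilation, while switching the flag compiles the other branch separately.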