Seems reasonable to me. One comment, though: I think you'll find dividing by 2 to be much faster than ldexp in JAX, because XLA has no ldexp primitive.
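If you want to check this on your own install, here is a minimal sketch that compares halving with `x / 2` against `jnp.ldexp(x, -1)`. The helper names `halve_div` and `halve_ldexp` are hypothetical, and it does not time anything. It checks that the two give the same values for ordinary (normal) floats and prints both traced programs with `jax.make_jaxpr`, so you can see how many ops each one expands to. The exact jaxpr depends on the JAX version.

```python
import jax
import jax.numpy as jnp


def halve_div(x):
    # Division by a constant power of two: a single elementwise divide.
    return x / 2


def halve_ldexp(x):
    # ldexp(x, -1) == x * 2**-1. Since XLA has no ldexp primitive,
    # JAX has to build jnp.ldexp out of other elementwise ops.
    return jnp.ldexp(x, -1)


x = jnp.array([-3.0, 0.1, 7.5, 1e30], dtype=jnp.float32)

# Halving a normal float is exact, so both should agree on these inputs.
print(jnp.allclose(halve_div(x), halve_ldexp(x)))

# Compare the traced programs: a single `div` versus a longer op sequence.
print(jax.make_jaxpr(halve_div)(x))
print(jax.make_jaxpr(halve_ldexp)(x))
```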
Looks like we did this!
@jakevdp Out of curiosity, is there a reason why XLA doesn't have an ldexp primitive, or has it just not been requested/added yet? Seems like it could potentially yield speed optimizations when users multiply/divide floats by a known power of 2.
Add the squareplus function, as described in Squareplus: A Softplus-Like Algebraic Rectifier, to jax.nn. It's a smooth rectifier like softplus, but it avoids softplus's numerical stability issues and is faster to compute (one algebraic function instead of two transcendental ones):

squareplus(x, b) = (x + sqrt(x^2 + b)) / 2
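Here is a minimal sketch of how this could look in JAX. It is a standalone function, not the jax.nn API. It follows the paper's definition (restated in the docstring), and the default `b = 4.0` is an assumption for this sketch. In line with the ldexp comment earlier in the thread, it halves with a plain division by 2.

```python
import jax
import jax.numpy as jnp


def squareplus(x, b=4.0):
    """Squareplus rectifier: (x + sqrt(x**2 + b)) / 2, for b >= 0.

    b = 4.0 is an assumed default for this sketch; b = 0 gives ReLU.
    """
    x = jnp.asarray(x)
    # One algebraic op (sqrt) instead of softplus's exp and log.
    return (x + jnp.sqrt(x * x + b)) / 2


x = jnp.array([-1.5, 0.0, 1.5])
print(squareplus(x))              # values: 0.5, 1.0, 2.0
print(squareplus(x, b=0.0))       # ReLU: 0.0, 0.0, 1.5
print(jax.grad(squareplus)(1.5))  # approx 0.8, i.e. (1 + 1.5 / 2.5) / 2
```

With b = 0 the function reduces to ReLU, since (x + |x|) / 2 = max(x, 0). Larger b gives a smoother bend around zero.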
I can submit a PR.