jax-ml / jax

Composable transformations of Python+NumPy programs: differentiate, vectorize, JIT to GPU/TPU, and more
http://jax.readthedocs.io/
Apache License 2.0

Add nn.squareplus #18519

Closed: carlosgmartin closed this issue 11 months ago

carlosgmartin commented 11 months ago

Add the squareplus function, as described in Squareplus: A Softplus-Like Algebraic Rectifier, to jax.nn. It's a smooth rectifier like softplus, but without softplus's numerical-stability issues, and it's faster, since it needs one algebraic function (a square root) instead of two transcendental functions:

import timeit

import jax
from jax import numpy as jnp, nn, random

def squareplus(x, b=4):
    # squareplus(x, b) = (x + sqrt(x^2 + b)) / 2
    y = x + jnp.sqrt(jnp.square(x) + b)
    return jnp.ldexp(y, -1)  # fast division by 2

key = random.PRNGKey(0)
x = random.normal(key, [10**8])

jit_softplus = jax.jit(nn.softplus)
jit_squareplus = jax.jit(squareplus)

# Warm up: trigger compilation before timing.
jit_softplus(x).block_until_ready()
jit_squareplus(x).block_until_ready()

start = timeit.default_timer()
jit_softplus(x).block_until_ready()
print(timeit.default_timer() - start)

start = timeit.default_timer()
jit_squareplus(x).block_until_ready()
print(timeit.default_timer() - start)

Output:

0.23485908400016342
0.0861067619998721

I can submit a PR.

jakevdp commented 11 months ago

Seems reasonable to me – one comment though, I think you'll find dividing by 2 to be much faster than ldexp in JAX, because XLA has no ldexp primitive.
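A minimal sketch of that suggestion, replacing the ldexp call with an ordinary division by 2 (the helper name here is just illustrative, not part of the proposal):

```python
from jax import numpy as jnp

def squareplus_div(x, b=4):
    # Same math as the proposal, but halving with a plain division
    # rather than jnp.ldexp, since XLA has no ldexp primitive.
    return (x + jnp.sqrt(jnp.square(x) + b)) / 2
```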

hawkinsp commented 11 months ago

Looks like we did this!
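For reference, a minimal usage sketch of the merged function, assuming it follows the proposed signature squareplus(x, b=4):

```python
import jax.numpy as jnp
from jax import nn

x = jnp.linspace(-3.0, 3.0, 7)
y = nn.squareplus(x)         # default hinge softness, b=4
y2 = nn.squareplus(x, b=1.0) # smaller b tracks relu more closely
```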

carlosgmartin commented 7 months ago

@jakevdp Out of curiosity, is there a reason why XLA doesn't have an ldexp primitive, or has it just not been requested/added yet? Seems like it could potentially yield speed optimizations when users multiply/divide floats by a known power of 2.
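For context, a small illustration of the operation in question: ldexp(x, n) computes x * 2**n, so scaling by a known power of two can also be written as an ordinary multiply.

```python
import jax.numpy as jnp

x = jnp.array([1.0, 2.0, 3.0])
print(jnp.ldexp(x, -1))  # x * 2**-1 -> [0.5, 1.0, 1.5]
print(x * 0.5)           # same values via a plain multiplication
```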