google / jax

Composable transformations of Python+NumPy programs: differentiate, vectorize, JIT to GPU/TPU, and more
http://jax.readthedocs.io/
Apache License 2.0

Add `jax.nn.normalize` #20556

Open carlosgmartin opened 5 months ago

carlosgmartin commented 5 months ago

jax.nn contains the function jax.nn.standardize, which standardizes values to zero mean and unit variance.

Another very common operation in data processing is min-max normalization, which linearly rescales values so that the minimum and maximum map to the endpoints of a given interval (commonly [0, 1]).
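To make the distinction concrete, here is a small sketch comparing `jax.nn.standardize` with min-max scaling to [0, 1] written by hand with `jax.numpy`. The array `x` is made up for illustration.

```python
import jax
import jax.numpy as jnp

x = jnp.array([2.0, 4.0, 6.0, 10.0])

# Standardization (jax.nn.standardize): (x - mean) / sqrt(var + epsilon),
# computed over axis=-1 by default, giving zero mean and ~unit variance.
z = jax.nn.standardize(x)

# Min-max normalization: (x - min) / (max - min), mapping min -> 0 and max -> 1.
m = (x - x.min()) / (x.max() - x.min())
# m == [0.0, 0.25, 0.5, 1.0]
```

Standardized values are unbounded, whereas min-max values span exactly [0, 1] for any non-constant input, which is the property this request is after.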

Feature request: Add a function jax.nn.normalize that does this.

Possible implementation:

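import jax.numpy as jnp

def normalize(a, axis=None, where=None, lb=0, ub=1):
    """Applies a min-max normalization of values to the interval [lb, ub]."""
    a = jnp.asarray(a)
    a = a.astype(jnp.result_type(a, 0.0))  # promote integer inputs to float so the ±inf initials are representable
    a_max = a.max(axis=axis, where=where, keepdims=True, initial=-jnp.inf)
    a_min = a.min(axis=axis, where=where, keepdims=True, initial=+jnp.inf)
    a_ptp = a_max - a_min  # peak-to-peak range
    a_ptp = jnp.where(a_ptp == 0, 1, a_ptp)  # avoid division by zero for constant slices
    a_norm = (a - a_min) / a_ptp  # rescaled to [0, 1]
    return lb + (ub - lb) * a_norm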
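
A usage sketch of the proposed function. The implementation is repeated here so the snippet runs on its own; `normalize` is only the name proposed in this issue, not an existing `jax.nn` API, and the input array and mask are made up for illustration. It rescales each row separately and shows how `where` interacts with the range:

```python
import jax.numpy as jnp

def normalize(a, axis=None, where=None, lb=0, ub=1):
    """Min-max normalization to [lb, ub], as proposed in this issue (hypothetical, not in jax.nn)."""
    a = jnp.asarray(a)
    a = a.astype(jnp.result_type(a, 0.0))  # promote integers to float
    a_max = a.max(axis=axis, where=where, keepdims=True, initial=-jnp.inf)
    a_min = a.min(axis=axis, where=where, keepdims=True, initial=+jnp.inf)
    a_ptp = a_max - a_min
    a_ptp = jnp.where(a_ptp == 0, 1, a_ptp)  # constant slices map to lb
    return lb + (ub - lb) * (a - a_min) / a_ptp

x = jnp.array([[1, 3, 5],
               [2, 2, 2]])

# Row-wise rescaling to [-1, 1]; the constant second row has zero range.
rows = normalize(x, axis=1, lb=-1, ub=1)
# rows == [[-1, 0, 1], [-1, -1, -1]]

# Exclude the 5 so that only 1 and 3 define the range of the first row.
mask = jnp.array([[True, True, False],
                  [True, True, True]])
masked = normalize(x, axis=1, where=mask)
# masked == [[0, 1, 2], [0, 0, 0]]
```

Note that `where` only controls which entries define the minimum and maximum; every entry is still rescaled, so excluded values can land outside [lb, ub]. A fully masked slice would produce NaNs, because its max and min stay at -inf and +inf.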

I can submit a PR for this.

jakevdp commented 5 months ago

jax.nn.normalize already exists: it's a deprecated alias to jax.nn.standardize. Before we add another function of the same name, we'll have to finalize the deprecation, and then it would probably be good to wait several releases before introducing any new normalize function in order to avoid confusion.

carlosgmartin commented 5 months ago

I see. How long would that take?

Would you prefer to wait until then, or to use a different name, such as minmax_normalize?

jakevdp commented 5 months ago

I'm not entirely convinced of the need for this function in jax.nn.