google-deepmind / optax

Optax is a gradient processing and optimization library for JAX.
https://optax.readthedocs.io
Apache License 2.0
1.65k stars 181 forks

Add axis and where arguments to loss functions. #912

Closed carlosgmartin closed 1 month ago

carlosgmartin commented 5 months ago

902

vroulet commented 5 months ago

Thanks @carlosgmartin, could you add tests? Also you will need to wait for #916 to pass.

fabianp commented 3 months ago

@carlosgmartin, there are some conflicts with main. Do you mind updating the pull request? Thanks.

carlosgmartin commented 1 month ago

@vroulet @fabianp Let me know if you'd like me to make any other changes.

fabianp commented 1 month ago

hey @carlosgmartin, thanks for the ping and apologies for the late reply, most of the team is on vacation 🏖️

The addition of the axis argument looks good to me, and the use of an axis kwarg is fairly common in numpy-like functions. ✅

Regarding the "where" argument however, I haven't seen it yet in other libraries. Do you know of any numpy/jax/flax/etc. functions that admit a "where" or a "mask" or similar kwarg? I just want to make sure that our API is as similar as possible to other libraries that have already implemented similar functionality.

carlosgmartin commented 1 month ago

Most reduction functions in numpy/jax take a where argument for masking. For example:

Conceptually, it makes sense that any reduction function should take both axis and where arguments.
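
To illustrate the masking behaviour described above, here is a minimal sketch using NumPy reductions; jax.numpy reductions such as jnp.sum, jnp.mean and jnp.max accept the same axis and where keywords. The array values and the mask are made up for the illustration.

```python
import numpy as np

# Toy data: two rows, with only some entries marked as valid.
x = np.array([[1.0, 2.0, 3.0],
              [4.0, 5.0, 6.0]])
mask = np.array([[True, True, False],
                 [True, False, False]])

# Reduce along the last axis, counting only entries where mask is True.
row_sums = np.sum(x, axis=-1, where=mask)    # [3., 4.]
row_means = np.mean(x, axis=-1, where=mask)  # [1.5, 4.]

# max has no natural identity element, so NumPy requires `initial`
# whenever `where` is given.
row_maxes = np.max(x, axis=-1, where=mask, initial=-np.inf)  # [2., 4.]
```

Note that the masked mean divides by the number of selected entries, not by the full length of the axis.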

fabianp commented 1 month ago

excellent, thanks for the info!

fabianp commented 1 month ago

one more thing: could you please add the tag .. versionchanged:: 0.2.4 to the docstrings of the functions you've changed, explaining the change? For an example, see https://github.com/google-deepmind/optax/blob/main/optax/schedules/_inject.py#L114

fabianp commented 1 month ago

also, the new kwargs should be described in the function docstring (on top of adding the versionchanged tag)
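
A rough sketch of what such a docstring could look like follows. The function name example_loss, the argument descriptions and the wording of the version note are hypothetical and not taken from the PR; only the directive and version number come from the request above.

```python
def example_loss(predictions, targets, axis=-1, where=None):
  """Hypothetical loss, used only to illustrate the docstring layout.

  Args:
    predictions: Array of predicted values.
    targets: Array of target values, same shape as ``predictions``.
    axis: Axis or axes along which the loss is reduced.
    where: Optional boolean mask selecting the elements to include
      in the reduction.

  Returns:
    The loss reduced over ``axis``.

  .. versionchanged:: 0.2.4
    Added ``axis`` and ``where`` arguments.
  """
  # Body omitted on purpose: this sketch is about the docstring only.
  raise NotImplementedError
```

The versionchanged directive is standard Sphinx markup, so it should render as a version note in the generated documentation.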

carlosgmartin commented 1 month ago

@fabianp Done.

fabianp commented 1 month ago

thanks, we're almost there. Please add type annotations for these two new kwargs. These are likely:

where: chex.Array | None = None,
axis: int | tuple[int, ...] | None = -1
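
For context, a minimal sketch of how a loss with these two annotated kwargs might look. This is not the code from the PR: the name masked_squared_error and the choice of a mean reduction are assumptions, and jax.Array stands in for chex.Array so the snippet needs no dependency beyond JAX.

```python
from __future__ import annotations

import jax
import jax.numpy as jnp


def masked_squared_error(
    predictions: jax.Array,
    targets: jax.Array,
    axis: int | tuple[int, ...] | None = -1,
    where: jax.Array | None = None,
) -> jax.Array:
  """Mean squared error over `axis`, restricted to entries where `where` is True."""
  errors = jnp.square(predictions - targets)
  # jnp.mean forwards both kwargs; where=None means "use every element".
  return jnp.mean(errors, axis=axis, where=where)


predictions = jnp.array([[1.0, 2.0, 3.0]])
targets = jnp.zeros((1, 3))
mask = jnp.array([[True, True, False]])

full = masked_squared_error(predictions, targets)                # (1 + 4 + 9) / 3
masked = masked_squared_error(predictions, targets, where=mask)  # (1 + 4) / 2
```

Defaulting axis to -1 reduces over the trailing dimension only, so any leading batch dimensions are preserved in the output.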

carlosgmartin commented 1 month ago

@fabianp Done.