google / flax

Flax is a neural network library for JAX that is designed for flexibility.
https://flax.readthedocs.io
Apache License 2.0
6.17k stars, 649 forks

NNX documentation missing pooling operations #4271

Open tux-type opened 1 month ago

tux-type commented 1 month ago

The NNX API reference does not seem to have entries for pooling operations like nnx.avg_pool or nnx.max_pool.

cgarciae commented 1 month ago

@8bitmp3

8bitmp3 commented 1 month ago

thanks @tux-type @cgarciae

8bitmp3 commented 3 weeks ago

Can you help find the source code, @cgarciae? https://github.com/search?q=repo%3Agoogle%2Fflax+path%3A%2F%5Eflax%5C%2Fnnx%5C%2F%2F+max_pool&type=code

jorisSchaller commented 1 week ago

Hey, I was also looking at the pooling. The flax/core/nn/__init__.py file imports the functions from the old Linen API: https://github.com/google/flax/blob/5d896bc1a2c68e2099d147cd2bc18ebb6a46a0bd/flax/core/nn/__init__.py#L38

But the entire file uses only https://github.com/google/flax/blob/5d896bc1a2c68e2099d147cd2bc18ebb6a46a0bd/flax/linen/pooling.py#L17-L19. We should copy it to flax/core/nn/pooling.py in order to have it in the new NNX API. We should also copy the pooling docs from https://github.com/google/flax/blob/main/docs/api_reference/flax.linen/layers.rst#pooling to a new file called docs_nnx/api_reference/flax.nnx/nn/pooling.rst to have up-to-date NNX documentation.

If you agree with the changes, I can send a PR with them.

cgarciae commented 1 week ago

@jorisSchaller happy to review the PR!