google / flax

Flax is a neural network library for JAX that is designed for flexibility.
https://flax.readthedocs.io
Apache License 2.0

Add functools.wraps() annotation to flax.nn.jit. #4051

Closed: copybara-service[bot] closed this 3 months ago

copybara-service[bot] commented 3 months ago

Add functools.wraps() annotation to flax.nn.jit.

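At the moment, every jit in a jaxpr shows up under the name "jitted", regardless of which function was wrapped. functools.partial does not forward names: a partial object has no __name__ attribute of its own, so the wrapped function's name is lost.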
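The naming problem can be sketched in plain Python. The helpers `jit_without_wraps`, `jit_with_wraps`, and `dense_layer` below are hypothetical illustrations, not the actual flax.nn.jit implementation. They only show that a wrapper whose inner function is named `jitted` reports that name unless functools.wraps copies the wrapped function's metadata onto it, and that a functools.partial object has no name to forward.

```python
import functools


def jit_without_wraps(fun):
    """Hypothetical jit-style wrapper that does NOT use functools.wraps."""
    def jitted(*args, **kwargs):
        # A real jit wrapper would trace and compile `fun`; this sketch just calls it.
        return fun(*args, **kwargs)
    return jitted


def jit_with_wraps(fun):
    """Same hypothetical wrapper, but copies fun's metadata onto the inner function."""
    @functools.wraps(fun)
    def jitted(*args, **kwargs):
        return fun(*args, **kwargs)
    return jitted


def dense_layer(x, scale):
    # Hypothetical user function whose name we would like to see in traces.
    return x * scale


# Without wraps, every wrapped function reports the inner function's name.
print(jit_without_wraps(dense_layer).__name__)  # jitted

# With wraps, __name__ (plus __qualname__, __doc__, __wrapped__) come from dense_layer.
print(jit_with_wraps(dense_layer).__name__)  # dense_layer

# A functools.partial object has no __name__ of its own; the name lives on .func.
bound = functools.partial(dense_layer, scale=2.0)
print(hasattr(bound, "__name__"))  # False
print(bound.func.__name__)  # dense_layer
```

Since the jaxpr labels reflect the Python function's name (which is why they all read "jitted" here), copying the name with functools.wraps is what lets each jit show up under the name of the function it wraps.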