Closed: copybara-service[bot] closed this 3 months ago.
Add `functools.wraps()` decorator to `flax.nn.jit`.
Currently, every jitted function in a jaxpr is named "jitted". `functools.partial` does not forward the original function's name.
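The naming problem can be reproduced with the standard library alone. This is a minimal sketch, not Flax's actual code: `jit_without_wraps`, `jit_with_wraps`, and `dense_apply` are hypothetical stand-ins for a jit wrapper and a user function. It assumes the name shown in a jaxpr comes from the wrapped callable's `__name__`. The sketch shows that an inner function keeps its own name ("jitted") unless `functools.wraps` copies the original `__name__` onto it. It also shows that a `functools.partial` object has no `__name__` to forward in the first place.

```python
import functools


def jit_without_wraps(fn):
    # Hypothetical wrapper: the returned function keeps its own name, "jitted".
    def jitted(*args, **kwargs):
        return fn(*args, **kwargs)
    return jitted


def jit_with_wraps(fn):
    # Same wrapper, but functools.wraps copies __name__, __qualname__, __doc__
    # (and sets __wrapped__) from fn onto the returned function.
    @functools.wraps(fn)
    def jitted(*args, **kwargs):
        return fn(*args, **kwargs)
    return jitted


def dense_apply(x):
    # Hypothetical user function whose name we want to see in the jaxpr.
    return 2 * x


if __name__ == "__main__":
    print(jit_without_wraps(dense_apply).__name__)  # jitted
    print(jit_with_wraps(dense_apply).__name__)     # dense_apply

    # partial objects define no __name__, so there is nothing to forward;
    # functools.wraps silently skips missing attributes, leaving "jitted".
    p = functools.partial(dense_apply)
    print(hasattr(p, "__name__"))                   # False
    print(jit_with_wraps(p).__name__)               # jitted
```

Because `functools.wraps` ignores attributes the wrapped object lacks, applying it on top of a `partial` does not recover the name. The decorator has to be applied with access to the original function.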