google/flax
Flax is a neural network library for JAX that is designed for flexibility.
https://flax.readthedocs.io
Apache License 2.0
5.78k stars, 610 forks
[nnx] fix grad #4007
Closed: cgarciae closed 1 week ago
cgarciae commented 2 weeks ago
What does this PR do?

Fixes a bug in `nnx.grad` where `grad` was not creating new references for non-differentiable NNX objects found in `*args`.
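The idea behind "creating new references" can be sketched in plain JAX. This is a minimal, hypothetical illustration, not Flax's implementation. It assumes an NNX-style transform splits every object argument into static structure plus array state outside the transform and rebuilds fresh objects inside the traced function. `Linear`, `Scale`, `split`, `merge` and `grad_fresh_refs` are illustrative names, not Flax APIs. The non-differentiable `Scale` passed after the first argument is rebuilt too, so the function body never receives the caller's original reference.

```python
import jax
import jax.numpy as jnp


class Linear:
    """Hypothetical stand-in for an NNX module with differentiable params."""

    def __init__(self, w, b):
        self.w = w
        self.b = b


class Scale:
    """Hypothetical non-differentiable object passed through *args."""

    def __init__(self, factor):
        self.factor = factor


def split(obj):
    # Static part (the class) plus a copy of the attribute state.
    return type(obj), dict(vars(obj))


def merge(cls, state):
    # Build a *new* instance from state, skipping __init__.
    new = cls.__new__(cls)
    vars(new).update(state)
    return new


def grad_fresh_refs(f):
    """Gradient of f w.r.t. the state of its first argument.

    Every object argument, differentiable or not, is split outside the
    transform and merged back inside it, so f only sees fresh references.
    """

    def wrapper(obj, *args):
        cls0, state0 = split(obj)
        rest = [split(a) for a in args]

        def pure(diff_state, other_states):
            fresh_obj = merge(cls0, diff_state)
            fresh_args = [merge(c, s) for (c, _), s in zip(rest, other_states)]
            return f(fresh_obj, *fresh_args)

        # Only argument 0 (diff_state) is differentiated.
        return jax.grad(pure)(state0, [s for _, s in rest])

    return wrapper


seen = {}


def loss(model, scale):
    seen["model"], seen["scale"] = model, scale  # record what f receives
    x = jnp.array([1.0, 2.0])
    return (x @ model.w + model.b) * scale.factor


model = Linear(jnp.array([0.5, -1.0]), jnp.array(0.0))
scale = Scale(3.0)
grads = grad_fresh_refs(loss)(model, scale)
# grads: {"w": [3., 6.], "b": 3.}
```

In this sketch, the behavior the PR describes as the bug would correspond to forwarding `scale` unchanged instead of merging it from its state, in which case `seen["scale"] is scale` would hold.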
Allows `vmap`'s `state_axes` mapping to accept `None` axes.
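For intuition on what a `None` axis means, here is a plain-JAX sketch. It assumes `None` in a `state_axes` mapping carries the same meaning as `None` in `jax.vmap`'s `in_axes`: state matching that filter is broadcast to every batch element rather than split along an axis. The `axes_for` helper and the key-name predicates are hypothetical stand-ins for NNX filters, not the NNX API.

```python
import jax
import jax.numpy as jnp


def axes_for(state, state_axes):
    """Map each state key to an axis using (predicate, axis) pairs.

    Hypothetical stand-in for an NNX filter -> axis mapping; the first
    matching predicate wins, and an axis of None means "broadcast".
    """
    return {k: next(axis for pred, axis in state_axes if pred(k)) for k in state}


def apply(state, x):
    # Per-example weight, shared bias.
    return x * state["w"] + state["b"]


state = {"w": jnp.array([1.0, 2.0, 3.0]), "b": jnp.array(10.0)}
xs = jnp.array([1.0, 1.0, 1.0])

state_axes = [
    (lambda k: k == "w", 0),  # vectorize "w" along axis 0
    (lambda k: True, None),   # everything else: broadcast, not vectorized
]
in_axes = (axes_for(state, state_axes), 0)

out = jax.vmap(apply, in_axes=in_axes)(state, xs)
# out: [11., 12., 13.], since row i computes xs[i] * w[i] + b
```

Had `b` been assigned axis 0 instead, `jax.vmap` would raise an error, since the scalar bias has no axis 0 to map over.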