google / flax

Flax is a neural network library for JAX that is designed for flexibility.
https://flax.readthedocs.io
Apache License 2.0

[nnx] add PathContains Filter #4011

Closed: cgarciae closed this 1 week ago

cgarciae commented 2 weeks ago

What does this PR do?

Adds the PathContains Filter, which lets users select the parts of a model's state whose path contains a specific key. Example:

from flax import nnx

class Model(nnx.Module):
  def __init__(self, rngs):
    self.backbone = nnx.Linear(2, 3, rngs=rngs)
    self.head = nnx.Linear(3, 10, rngs=rngs)

model = Model(nnx.Rngs(0))

head_state = nnx.state(model, nnx.PathContains('head'))

assert 'head' in head_state
assert 'backbone' not in head_state
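To make the selection rule concrete, here is a minimal pure-Python sketch of what a path-contains filter does. It assumes that an NNX Filter can be a predicate called as filter(path, value), where path is a tuple of keys. PathContainsSketch and flat_state are hypothetical names used only for illustration and are not part of the Flax API. The real nnx.PathContains may differ in details.

```python
from dataclasses import dataclass
from typing import Any, Hashable, Tuple


# Hypothetical stand-in for nnx.PathContains (illustration only).
# Assumption: a Filter is a predicate called as filter(path, value),
# where `path` is a tuple of keys such as ("head", "kernel").
@dataclass(frozen=True)
class PathContainsSketch:
    key: Hashable

    def __call__(self, path: Tuple[Hashable, ...], value: Any) -> bool:
        # True when `key` equals any element of the path, at any depth.
        # Tuple membership is exact, so "head" does not match "header".
        return self.key in path


# Hypothetical flattened state: path tuples mapped to parameter shapes.
flat_state = {
    ("backbone", "kernel"): (2, 3),
    ("backbone", "bias"): (3,),
    ("head", "kernel"): (3, 10),
    ("head", "bias"): (10,),
}

is_head = PathContainsSketch("head")
# Keep only the entries whose path contains "head".
head_only = {path: v for path, v in flat_state.items() if is_head(path, v)}
print(sorted(head_only))  # [('head', 'bias'), ('head', 'kernel')]
```

Because the check is plain membership in the path tuple, the key also matches when it appears deeper in the path, for example under a wrapper module at ("model", "head", "kernel").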