google / flax

Flax is a neural network library for JAX that is designed for flexibility.
https://flax.readthedocs.io
Apache License 2.0

`optax.multi_transform` + `nnx.State`/`nnx.Optimizer` troubles #3955

Closed cgarciae closed 4 weeks ago

cgarciae commented 1 month ago

Discussed in https://github.com/google/flax/discussions/3954

Originally posted by **yklcs** June 1, 2024

[`optax.multi_transform`](https://optax.readthedocs.io/en/latest/api/combining_optimizers.html#multi-transform) defines multiple transforms via a `Mapping[Hashable, GradientTransformation]` and uses a PyTree or a function to map each parameter to a key. Using `optax.multi_transform` with `nnx.Optimizer` means the label mapping must be of type `nnx.State`. However, `nnx.State` is typed to hold `StateLeaf` values, so we can't use string or integer labels as leaves. Ignoring the type annotation does work, but it feels brittle and might break later. Is there another solution to this problem?

```python
tx = optax.multi_transform(
    {
        "weights": optax.adamw(learning_rate, momentum),
        "biases": optax.adamw(learning_rate, momentum),
    },
    # this doesn't work:
    # {
    #     "weights": "weights",
    #     "biases": "biases",
    # },
    # this does, but is it safe?:
    nnx.State({"weights": "weights", "biases": "biases"}),
)
```