Originally posted by **yklcs** June 1, 2024
[`optax.multi_transform`](https://optax.readthedocs.io/en/latest/api/combining_optimizers.html#multi-transform) defines multiple transforms with a `Mapping[Hashable, GradientTransformation]` and uses a PyTree or function to map parameters to the key.
Using `optax.multi_transform` with `nnx.Optimizer` means that mapping must be an `nnx.State`.
However, `nnx.State` is typed with `StateLeaf` as its leaf type, which rules out plain string or integer labels.
Ignoring the typing does work at runtime, but it feels brittle and might break later.
Is there another solution to this problem?
```python
tx = optax.multi_transform(
    {
        "weights": optax.adamw(learning_rate, momentum),
        "biases": optax.adamw(learning_rate, momentum),
    },
    # this doesn't work -- plain strings aren't valid `StateLeaf` values:
    # {
    #     "weights": "weights",
    #     "biases": "biases",
    # },
    # this does, but is it safe?:
    nnx.State({
        "weights": "weights",
        "biases": "biases",
    }),
)
```
Discussed in https://github.com/google/flax/discussions/3954