keras-team / keras

Deep Learning for humans
http://keras.io/
Apache License 2.0
61.86k stars 19.44k forks

[Feature Request] Support pytree (nested list) in optimizer `build` #18443

Open refraction-ray opened 1 year ago

refraction-ray commented 1 year ago
import jax
import keras

w = jax.numpy.ones([4, 1])
b1 = jax.numpy.ones([1])
b2 = jax.numpy.ones([1])

# Flat list of variables
opt = keras.optimizers.Adam(1e-2)
opt.build([w, b1, b2]) # ok

# Nested list (pytree) of variables
opt = keras.optimizers.Adam(1e-2)
opt.build([w, [b1, b2]]) # fails
# AttributeError: 'list' object has no attribute 'shape'

The latter case is very common when one uses a functional programming paradigm: model.variables is a list of tensors (similar to b1 and b2 above), and there may be other variables outside the model (similar to w above) that the user also wants to optimize together. Full pytree support in optimizer.build would be fantastic.

sachinprasadhs commented 7 months ago

Both cases now fail with the error AttributeError: 'tensorflow.python.framework.ops.EagerTensor' object has no attribute '_unique_id'

Attached Gist here for reference

fchollet commented 7 months ago
  1. Why not just call flatten on your structure before calling build()?
  2. If you pass a nested structure to build(), do you expect to also pass the same nested structure in stateless_apply(optimizer_variables, grads, trainable_variables)? (as optimizer_variables)
refraction-ray commented 7 months ago

@fchollet

  1. That question could be asked of every pytree-compatible API. It is simply more elegant and easier to use when an API accepts pytree structures directly, as most Keras APIs do.
  2. Yes.
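
For reference, the flatten-first workaround from point 1 can be sketched roughly as below. This is only an illustration under stated assumptions: `flatten` and `unflatten` are hypothetical helpers written in plain Python (they are not Keras or JAX functions), and strings stand in for the arrays w, b1 and b2 so the snippet runs without any backend. With JAX installed, `jax.tree_util.tree_flatten` and `jax.tree_util.tree_unflatten` play the same role for real pytrees.

```python
# Hypothetical helpers (not part of Keras or JAX): flatten a nested
# list/tuple into a flat list of leaves, and rebuild the nesting later.


def flatten(tree):
    """Return (leaves, structure) for a nested list/tuple."""
    if isinstance(tree, (list, tuple)):
        leaves, children = [], []
        for child in tree:
            child_leaves, child_structure = flatten(child)
            leaves.extend(child_leaves)
            children.append(child_structure)
        # Remember the container type so tuples stay tuples.
        return leaves, (type(tree), children)
    return [tree], None  # None marks a leaf position


def unflatten(structure, leaves):
    """Inverse of flatten: rebuild the nesting from a flat list of leaves."""
    it = iter(leaves)

    def build(node):
        if node is None:
            return next(it)  # consume leaves in the order flatten produced them
        container, children = node
        return container(build(child) for child in children)

    return build(structure)


# Strings stand in for the arrays w, b1, b2 from the original report.
nested = ["w", ["b1", "b2"]]
flat, structure = flatten(nested)      # flat is what would go to opt.build(...)
restored = unflatten(structure, flat)  # same nesting as `nested`
```

The cost of this approach is that the caller has to carry `structure` around and apply the same flattening to the gradients and trainable variables passed to stateless_apply, which is the bookkeeping that native pytree support would remove.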