prabhudavidsheryl opened 2 months ago
I am having a similar issue with Orbax-checkpoint in `flax.linen`. I followed the `flax.linen` tutorial for `Dropout`. The `TrainState` in that case includes an additional attribute, `key`, which holds the PRNG key for dropout. When saving the `TrainState` with `orbax_checkpoint`, I got the error `TypeError: Cannot interpret 'key<fry>' as a data type`.

Hence, the problem is not specific to `flax.nnx`; it affects `flax.linen` as well.
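A common workaround, sketched here under assumptions rather than as the official fix: the error presumably comes from the typed key's extended dtype (`key<fry>`), which numpy cannot represent. Converting the key to its raw `uint32` bits with `jax.random.key_data` before saving, and back with `jax.random.wrap_key_data` after restoring, avoids that dtype. To keep the sketch self-contained, `np.save` into an in-memory buffer stands in for the Orbax checkpointer; that substitution is an assumption of this example.

```python
import io

import jax
import jax.numpy as jnp
import numpy as np

# A typed PRNG key, like the dropout `key` field added to the linen TrainState.
# Its dtype is a JAX extended dtype (printed as key<fry>), which numpy cannot represent.
dropout_key = jax.random.key(0)
assert jax.dtypes.issubdtype(dropout_key.dtype, jax.dtypes.prng_key)

# Before saving: extract the raw uint32 key bits, an ordinary array.
raw = jax.random.key_data(dropout_key)

# A numpy-based serializer can now store it (np.save stands in for the checkpointer here).
buf = io.BytesIO()
np.save(buf, np.asarray(raw))
buf.seek(0)
loaded = np.load(buf)

# After restoring: wrap the bits back into a typed key (default PRNG implementation).
restored_key = jax.random.wrap_key_data(jnp.asarray(loaded))

# The restored key draws the same dropout randomness as the original.
same_draws = bool(
    jnp.array_equal(
        jax.random.uniform(dropout_key, (3,)),
        jax.random.uniform(restored_key, (3,)),
    )
)
print(loaded.dtype, same_draws)
```

Applied to a linen `TrainState`, this would presumably look like `state.replace(key=jax.random.key_data(state.key))` before passing the state to Orbax, and `restored.replace(key=jax.random.wrap_key_data(restored.key))` afterwards. Alternatively, creating the dropout key with the legacy `jax.random.PRNGKey`, which already returns a raw `uint32` array, likely sidesteps the error altogether.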
I have been trying to use Orbax to checkpoint Flax NNX models, and getting checkpointing to work for models with `Dropout` layers, which also hold JAX RNG keys, is not very straightforward. After various attempts, this was the only way I could get it to work.
The comments describe the issues I ran into.
It would be good to make checkpointing easier for models that contain `Dropout` layers.
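For NNX, where the RNG keys live somewhere inside the model state, the same idea can be generalized to a whole pytree. The sketch below is an illustration under assumptions, not the Flax or Orbax API: the helpers `key_mask`, `keys_to_raw`, and `raw_to_keys` are hypothetical names, and the nested dict is a hypothetical stand-in for a real model state. It records which leaves are typed keys, swaps them for raw `uint32` data before saving, and re-wraps them after restoring.

```python
import jax
import jax.numpy as jnp


def _is_typed_key(x) -> bool:
    # True for typed PRNG key arrays (created with jax.random.key), False otherwise.
    dtype = getattr(x, "dtype", None)
    return dtype is not None and bool(jax.dtypes.issubdtype(dtype, jax.dtypes.prng_key))


def key_mask(tree):
    # Hypothetical helper: record which leaves are typed keys, so they can be re-wrapped later.
    return jax.tree_util.tree_map(_is_typed_key, tree)


def keys_to_raw(tree):
    # Hypothetical helper: replace every typed key leaf with its raw uint32 data.
    return jax.tree_util.tree_map(
        lambda x: jax.random.key_data(x) if _is_typed_key(x) else x, tree
    )


def raw_to_keys(tree, mask):
    # Hypothetical helper: re-wrap the leaves flagged in `mask` as typed keys
    # (default PRNG implementation); all other leaves pass through unchanged.
    return jax.tree_util.tree_map(
        lambda x, is_key: jax.random.wrap_key_data(x) if is_key else x, tree, mask
    )


# Hypothetical stand-in for a model state containing a Dropout RNG stream.
state = {
    "linear": {"kernel": jnp.ones((2, 3)), "bias": jnp.zeros((3,))},
    "rngs": {"dropout": {"key": jax.random.key(1), "count": jnp.array(0, dtype=jnp.uint32)}},
}

mask = key_mask(state)
raw_state = keys_to_raw(state)  # only ordinary arrays: safe to hand to a checkpointer
restored = raw_to_keys(raw_state, mask)  # typed keys again after loading
```

With NNX, one would presumably apply `keys_to_raw` to the state returned by `nnx.split(model)` before saving, and `raw_to_keys` before `nnx.merge(graphdef, state)` after loading. Whether `jax.tree_util.tree_map` reaches the key arrays inside an `nnx.State` depends on the Flax version, and the restored tree must have the same structure as the saved mask, so treat this as a starting point rather than a drop-in fix.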