floringogianu closed this 8 months ago.
Hi @floringogianu, thank you for this! Do you think it would be better to use `field` from `flax.struct`? It seems to be just a wrapper, but it is potentially more compatible with `PyTreeNode`.
Hi, you are right, it should use `flax.struct.field`. I made the changes; let me know if I should squash the commits.
This addresses #3. Checked on Python 3.12 and JAX 0.4.23.