google-deepmind / optax

Optax is a gradient processing and optimization library for JAX.
https://optax.readthedocs.io
Apache License 2.0

`ocp.tree.serialize_tree` filtering logic for sequences with empty leaves #1139

Closed JesseFarebro closed 1 week ago

JesseFarebro commented 1 week ago

Hi,

I spotted ocp.tree.serialize_tree, but it seems the serialization logic fails when a sequence contains empty leaves. This happens quite frequently with optax, where you end up with optax.EmptyState() inside a tuple. Here's a minimal reproduction of the issue:

import optax
import orbax.checkpoint as ocp

tree = (0, optax.EmptyState(), 1) # or None, etc.
ocp.tree.serialize_tree(tree)

resulting in:

  File ".../orbax/checkpoint/tree/utils.py", line 79, in _extend_list
    assert idx <= len(ls)
           ^^^^^^^^^^^^^^
AssertionError

I'm not sure what the ideal solution is here, since I don't have enough context on the intended purpose of serialize_tree and deserialize_tree.
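For context, here is a small sketch of why an empty node inside a tuple might trip index-based logic. It only uses JAX's pytree utilities and says nothing definitive about Orbax's internals; the link to the assertion above is a guess. `EmptyState` below is a hypothetical stand-in for optax.EmptyState (a NamedTuple with no fields), used so the snippet does not need optax installed.

```python
from typing import NamedTuple

import jax


class EmptyState(NamedTuple):
    """Hypothetical stand-in for optax.EmptyState: a NamedTuple with no fields."""


tree = (0, EmptyState(), 1)

# An empty node contributes no leaves, so only two leaves survive flattening.
leaves = jax.tree_util.tree_leaves(tree)

# The key paths keep the original tuple positions, which leaves a gap at index 1.
paths_and_leaves, _ = jax.tree_util.tree_flatten_with_path(tree)
leaf_indices = [path[0].idx for path, _ in paths_and_leaves]

print(leaves)        # [0, 1]
print(leaf_indices)  # [0, 2]
```

If the serialized tree is rebuilt from leaf paths alone, the jump from index 0 to index 2 could plausibly be what a check like `idx <= len(ls)` rejects, but that would need confirming against the Orbax source.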

vroulet commented 1 week ago

Hello @JesseFarebro

I believe this issue would be best answered on the Orbax GitHub repository. I don't have context on the intended purpose of serialize_tree and deserialize_tree either.

If an answer from the Orbax team leads you to believe that the issue is in Optax, feel free to reopen it. For now, I'll close it.

JesseFarebro commented 1 week ago

Oops, I definitely thought I had posted this to the Orbax repository. I guess the two project names are too close for me to differentiate with many tabs open. Sorry for the inconvenience 😅