google / objax

Apache License 2.0

`objax.variable.VarCollection.update` fails when passing `Dict[str, Any]` #253

Open alvarobartt opened 1 year ago

alvarobartt commented 1 year ago

Hi everyone! Thanks for the awesome work on objax and the JAX ecosystem, and happy holidays!

I was playing around with objax for a bit and noticed a problem with `VarCollection.update`, which overrides the default `dict.update` method. If you try to update `model.vars()` (a `VarCollection`) by passing a plain Python dictionary instead of a `VarCollection`, the call fails: the argument is cast into a Python list, and the loop then iterates over that list as if it held the dictionary's items. Since iterating over a dict yields only its keys, each key string gets unpacked as a `(key, value)` pair, which throws `ValueError: too many values to unpack (expected 2)`.
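To illustrate the failure mode without depending on objax internals, here is a minimal, hypothetical sketch. `FakeVarCollection` is an invented `dict` subclass that mimics the pattern described above; it is not objax's actual code, and the variable names are made up for the example.

```python
from typing import Any, Iterable, Tuple, Union


class FakeVarCollection(dict):
    """Hypothetical stand-in for objax.VarCollection (not objax's actual code).

    Mimics the pattern described in the issue: only a FakeVarCollection argument
    is converted with .items(); anything else is turned into a list and each
    element is unpacked as a (key, value) pair.
    """

    def update(self, other: Union["FakeVarCollection", Iterable[Tuple[str, Any]]]) -> None:
        if isinstance(other, FakeVarCollection):
            other = other.items()
        for k, v in list(other):  # list(plain_dict) contains only the keys
            self[k] = v


vc = FakeVarCollection()
vc.update(FakeVarCollection({"(Linear).w": 1.0}))  # works: .items() yields pairs
vc.update([("(Linear).b", 0.0)])                   # works: already (key, value) pairs

try:
    vc.update({"(Linear).w": 2.0})  # plain dict: each element is a key string
    error_message = None
except ValueError as err:
    error_message = str(err)

print(error_message)  # starts with "too many values to unpack"
```

In this sketch, a key that happens to be exactly two characters long would not raise at all: it would be silently unpacked into its two characters, which makes this kind of bug easy to miss.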

https://github.com/google/objax/blob/53b391bfa72dc59009c855d01b625049a35f5f1b/objax/variable.py#L311-L318

Is this intended? Shouldn't `VarCollection.update` simply loop over the items of any object that provides `.items()`?
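One way to get the behaviour suggested above is to duck-type on `.items()`. The sketch below is hypothetical (`PatchedVarCollection` is an invented name, not objax's API) and leaves out any other checks the real `VarCollection.update` may perform.

```python
from typing import Any, Iterable, Mapping, Tuple, Union


class PatchedVarCollection(dict):
    """Hypothetical sketch of the suggested behaviour (not objax code).

    Any argument exposing .items() (a collection like this one, a plain dict, ...)
    is converted to (key, value) pairs first; any other iterable is assumed to
    already yield (key, value) pairs.
    """

    def update(self, other: Union[Mapping[str, Any], Iterable[Tuple[str, Any]]]) -> None:
        pairs = other.items() if hasattr(other, "items") else other
        for k, v in pairs:
            self[k] = v


vc = PatchedVarCollection()
vc.update(PatchedVarCollection({"(Linear).w": 1.0}))  # collection: uses .items()
vc.update({"(Linear).b": 0.0})                        # plain dict: now accepted
vc.update([("(Linear).w", 2.0)])                      # iterable of pairs still works
print(sorted(vc.items()))  # [('(Linear).b', 0.0), ('(Linear).w', 2.0)]
```

Python's built-in `dict.update` follows a similar duck-typed rule, except that it checks for a `.keys()` method rather than `.items()`.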

alvarobartt commented 1 year ago

Hi @AlexeyKurakin (sorry for mentioning you here in case you're not actively working on this)! Can you confirm whether `.update` is expected to also accept a `VarCollection`? According to the type hints of `.update`, it is. If you want, I can work on this issue as well as the referenced one 😄 Thanks!