✨ Features
Split flax.py into save.py, load.py, and utils.py for readability
save.py contains serialize
load.py contains deserialize
utils.py contains both flatten_dict and unflatten_dict
Add a freeze_dict parameter to unflatten_dict to either convert the unflattened output into a FrozenDict (used for flax) or keep it as a plain Dict
Update unit tests with pytest to cover every safejax function
Test dm-haiku model param serialization using haiku.nets.ResNet50
Add more examples under examples/ for both flax and dm-haiku
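A minimal sketch of how a flatten_dict / unflatten_dict pair with a freeze_dict switch could work. This is not the safejax code: the "." key separator, the prefix argument, and the function bodies are assumptions. Only the function names, the freeze_dict parameter, and the FrozenDict target come from the notes above. `flax.core.frozen_dict.freeze` is flax's helper for building a FrozenDict, and it is imported only when freeze_dict=True.

```python
from typing import Any, Dict, Mapping

# Assumption: nested keys are joined with "." in the flat view.
SEP = "."


def flatten_dict(params: Mapping[str, Any], prefix: str = "") -> Dict[str, Any]:
    """Flatten a nested mapping into {"a.b.c": leaf} form (hypothetical signature)."""
    flat: Dict[str, Any] = {}
    for key, value in params.items():
        full_key = f"{prefix}{SEP}{key}" if prefix else key
        if isinstance(value, Mapping):
            # Recurse into sub-dicts; flax's FrozenDict is also a Mapping.
            flat.update(flatten_dict(value, full_key))
        else:
            flat[full_key] = value
    return flat


def unflatten_dict(flat: Mapping[str, Any], freeze_dict: bool = False) -> Any:
    """Rebuild the nested dict, optionally freezing it (hypothetical signature)."""
    root: Dict[str, Any] = {}
    for flat_key, value in flat.items():
        *parents, leaf = flat_key.split(SEP)
        # Walk with a separate cursor variable so `root` keeps
        # pointing at the top level of the rebuilt dict.
        node = root
        for part in parents:
            node = node.setdefault(part, {})
        node[leaf] = value
    if freeze_dict:
        # Lazy import: the plain-dict path does not require flax.
        from flax.core.frozen_dict import freeze
        return freeze(root)
    return root
```

A round-trip check such as `unflatten_dict(flatten_dict(params)) == params` is one way to write the kind of assertion described for safejax.utils in the tests.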
🐛 Bug Fixes
Fix a bug in unflatten_dict where a variable was being overwritten while unflattening dictionaries
🧪 Tests
[X] Did you implement unit tests if required?
If the above checkbox is checked, describe how you unit-tested it.
Add assertions to check that both safejax.utils.flatten_dict and safejax.utils.unflatten_dict work as expected, guarding against the unflatten_dict bug mentioned above
Add more unit tests for safejax.load and safejax.save following the recent split into separate files