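✨ Features
Add `tests/utils.py` and `tests/__init__.py` so that the test helpers can be loaded with a relative import such as `from .utils import assert_over_trees`
Implement the `assert_over_trees` function in `tests/utils.py`
Add `chex` assertions for `jax`, `flax`, and `haiku` params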
Linked Issue/s
24
Notes
`jaxtyping` is another library that could be used to assert over trees
`objax.variable.VarCollection` is not yet supported, as it still needs more refinement
🧪 Tests
[X] Did you implement unit tests if required?
If the above checkbox is checked, describe how you unit-tested it.
Added assertions over trees with `chex` (from @deepmind) to verify that the original params dict matches the one obtained after serializing and deserializing it with `safejax`, using `safetensors` as the tensor storage format.