alvarobartt / safejax

Serialize JAX, Flax, Haiku, or Objax model params with 🤗`safetensors`
https://alvarobartt.github.io/safejax/
MIT License

🧪 Add assertions over `pytrees` with `chex` #27

Closed alvarobartt closed 1 year ago

alvarobartt commented 1 year ago

✨ Features

🔗 Linked Issue/s

#24

📌 Notes

🧪 Tests

Adds assertions over pytrees with `chex` (from @deepmind) to make sure that the original params dict matches the one obtained after serializing and then deserializing it with `safejax`, using `safetensors` as the tensor storage format.