alvarobartt / safejax

Serialize JAX, Flax, Haiku, or Objax model params with 🤗`safetensors`
https://alvarobartt.github.io/safejax/
MIT License

✨ Add `metadata` in `safejax.serialize` #26

Closed: alvarobartt closed this 1 year ago

alvarobartt commented 1 year ago

✨ Features

🧪 Tests

If the above checkbox is checked, describe how you unit-tested it.

Add unit tests that call `serialize` with a `filename` (and `metadata`), and check that the metadata is loaded back by `deserialize`.
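Because the exact `safejax.serialize`/`deserialize` signatures aren't shown here, the round trip such a test checks can be sketched with the standard library against the safetensors file layout itself: an 8-byte little-endian header length, a JSON header whose optional `__metadata__` entry is a flat `str -> str` mapping, then the raw tensor bytes. The helpers `write_safetensors`, `read_safetensors` and `test_metadata_roundtrip` are hypothetical stand-ins, only handle float32 (`F32`), and are not safejax's or safetensors' API.

```python
import json
import os
import struct
import tempfile


def write_safetensors(path, tensors, metadata=None):
    """Hypothetical writer: float32 tensors given as {name: (shape, flat_values)}."""
    header, chunks, offset = {}, [], 0
    for name, (shape, values) in tensors.items():
        data = struct.pack(f"<{len(values)}f", *values)  # little-endian float32
        header[name] = {"dtype": "F32", "shape": list(shape),
                        "data_offsets": [offset, offset + len(data)]}
        chunks.append(data)
        offset += len(data)
    if metadata is not None:
        # safetensors metadata is a flat str -> str mapping; reject anything else
        if not all(isinstance(k, str) and isinstance(v, str) for k, v in metadata.items()):
            raise TypeError("metadata keys and values must be str")
        header["__metadata__"] = dict(metadata)
    header_bytes = json.dumps(header).encode("utf-8")
    header_bytes += b" " * (-len(header_bytes) % 8)  # pad header to 8-byte alignment
    with open(path, "wb") as f:
        f.write(struct.pack("<Q", len(header_bytes)))  # u64 header length
        f.write(header_bytes)
        f.write(b"".join(chunks))


def read_safetensors(path):
    """Hypothetical reader: returns ({name: (shape, flat_values)}, metadata or None)."""
    with open(path, "rb") as f:
        (header_len,) = struct.unpack("<Q", f.read(8))
        header = json.loads(f.read(header_len))  # trailing pad spaces are valid JSON whitespace
        buffer = f.read()
    metadata = header.pop("__metadata__", None)
    tensors = {}
    for name, info in header.items():
        start, end = info["data_offsets"]
        values = struct.unpack(f"<{(end - start) // 4}f", buffer[start:end])
        tensors[name] = (tuple(info["shape"]), list(values))
    return tensors, metadata


def test_metadata_roundtrip():
    # Serialize to a file with metadata, then check both params and metadata come back.
    params = {"dense/kernel": ((2, 2), [1.0, 2.0, 3.0, 4.0]),
              "dense/bias": ((2,), [0.0, 0.5])}
    with tempfile.TemporaryDirectory() as tmp:
        path = os.path.join(tmp, "params.safetensors")
        write_safetensors(path, params, metadata={"framework": "flax"})
        loaded, metadata = read_safetensors(path)
    assert loaded == params
    assert metadata == {"framework": "flax"}


test_metadata_roundtrip()
```

Validating metadata up front mirrors the fact that safetensors only accepts string keys and values, so a test along these lines may also want to cover non-string metadata.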

alvarobartt commented 1 year ago

According to what @narsil said at https://github.com/huggingface/safetensors/issues/147#issuecomment-1370237198, `load_file` is usually a small wrapper over `safe_open`. Since `load_file` only returns the tensors, while `safe_open` also exposes the header metadata through `.metadata()`, it's probably worth comparing the speed and resource consumption of `safe_open` against `load` and `load_file` to see whether we can indeed replace those calls with `safe_open`.
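A real comparison should time safetensors' own `safe_open`, `load` and `load_file`. As a stdlib-only sketch of the shape such a benchmark could take, the hypothetical `eager_load` reads the whole file (roughly a `load_file`-style access pattern), while the hypothetical `lazy_metadata` stops after the header (roughly what makes `safe_open` cheap when only metadata is needed). `write_f32_file`, both loaders and `compare` are illustrative assumptions, not library code, and the timings they report say nothing about the real library.

```python
import json
import os
import struct
import tempfile
import timeit


def write_f32_file(path, n_floats, metadata):
    """Hypothetical helper: one float32 tensor of zeros plus a metadata header."""
    data = bytes(4 * n_floats)  # n_floats zero-valued float32 entries
    header = {"weights": {"dtype": "F32", "shape": [n_floats],
                          "data_offsets": [0, len(data)]},
              "__metadata__": metadata}
    header_bytes = json.dumps(header).encode("utf-8")
    header_bytes += b" " * (-len(header_bytes) % 8)  # 8-byte alignment padding
    with open(path, "wb") as f:
        f.write(struct.pack("<Q", len(header_bytes)) + header_bytes + data)


def eager_load(path):
    """Stand-in for a load_file-style read: pull the whole file into memory."""
    with open(path, "rb") as f:
        raw = f.read()
    (header_len,) = struct.unpack("<Q", raw[:8])
    header = json.loads(raw[8:8 + header_len])
    return header.get("__metadata__"), len(raw)  # (metadata, bytes read)


def lazy_metadata(path):
    """Stand-in for a safe_open-style read: stop right after the header."""
    with open(path, "rb") as f:
        (header_len,) = struct.unpack("<Q", f.read(8))
        header = json.loads(f.read(header_len))
        return header.get("__metadata__"), f.tell()  # (metadata, bytes read)


def compare(path, repeat=5):
    """Best-of-`repeat` wall-clock seconds for each strategy."""
    eager = min(timeit.repeat(lambda: eager_load(path), number=1, repeat=repeat))
    lazy = min(timeit.repeat(lambda: lazy_metadata(path), number=1, repeat=repeat))
    return eager, lazy


with tempfile.TemporaryDirectory() as tmp:
    path = os.path.join(tmp, "params.safetensors")
    write_f32_file(path, 1_000_000, {"framework": "flax"})  # about 4 MB of tensor data
    eager_s, lazy_s = compare(path)
    print(f"eager: {eager_s:.5f}s ({eager_load(path)[1]} bytes read)")
    print(f"lazy:  {lazy_s:.5f}s ({lazy_metadata(path)[1]} bytes read)")
```

Reporting bytes read alongside time is one way to capture "resource consumption" without relying on noisy wall-clock numbers; memory profiling of the real calls would still be needed.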

Note that replacing `load` is most likely not possible: `safe_open` expects a file, so it only works if the data is actually written to disk between `save` and `safe_open`. Instead, we can document that `metadata` is only stored when dumping to a safetensors file, and is lost otherwise.