alvarobartt / safejax

Serialize JAX, Flax, Haiku, or Objax model params with 🤗`safetensors`
https://alvarobartt.github.io/safejax/
MIT License

Define custom type in `safejax.typing` for model params #11

Closed: alvarobartt closed this issue 1 year ago

alvarobartt commented 1 year ago

Something like:

```python
ParamsLike = Union[Dict[str, jnp.DeviceArray], Dict[str, np.ndarray], FrozenDict, VarCollection]
```

This would avoid spelling out the full union of types in every signature, and would keep the type hints of `safejax.serialize` and `safejax.deserialize` aligned 🤗
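A minimal sketch of what such a `safejax/typing.py` module could look like, assuming the alias is defined once and imported wherever params are accepted. Several details are assumptions, not taken from the issue. The framework imports sit under `TYPE_CHECKING` with string forward references, so that users with only Flax (or only Objax) installed can still import the alias. `jax.Array` stands in for `jnp.DeviceArray`, since recent JAX versions use `jax.Array`. The `flatten_params` helper and its `delimiter` parameter are hypothetical and only show one function consuming the shared alias.

```python
"""Hypothetical `safejax/typing.py`: one shared alias for model params."""
from __future__ import annotations

from typing import TYPE_CHECKING, Any, Dict, Mapping, Union

if TYPE_CHECKING:
    # Only evaluated by static type checkers, so none of these frameworks
    # becomes a hard runtime dependency just to import the alias.
    import jax
    import numpy as np
    from flax.core.frozen_dict import FrozenDict
    from objax.variable import VarCollection

# String forward references keep the alias importable at runtime even when
# Flax or Objax is not installed. `jax.Array` replaces the older
# `jnp.DeviceArray` (assumption: targeting recent JAX versions).
ParamsLike = Union[
    Dict[str, "jax.Array"],
    Dict[str, "np.ndarray"],
    "FrozenDict",
    "VarCollection",
]


def flatten_params(params: ParamsLike, delimiter: str = ".") -> Dict[str, Any]:
    """Hypothetical helper: flatten nested params into a flat {name: array} dict.

    FrozenDict and VarCollection are both Mappings, so one code path
    handles every member of ParamsLike.
    """
    flat: Dict[str, Any] = {}

    def _walk(prefix: str, node: Any) -> None:
        if isinstance(node, Mapping):
            # Recurse into nested collections, joining keys with the delimiter.
            for key, value in node.items():
                _walk(f"{prefix}{delimiter}{key}" if prefix else str(key), value)
        else:
            # Leaf value (an array): store it under its full dotted path.
            flat[prefix] = node

    _walk("", params)
    return flat
```

With the alias in one place, both `serialize` and `deserialize` could annotate their params as `ParamsLike`. Adding or removing a supported framework type would then be a one-line change.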