Feature request
Currently, FlaxPreTrainedModel only supports loading pretrained models from sharded PyTorch weights or from single-file .safetensors weights. It would be worth adding support for loading sharded .safetensors weights.
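The core of the requested feature is resolving a sharded checkpoint's index file (in transformers, `model.safetensors.index.json`, whose `weight_map` maps each tensor name to the shard file that holds it) and merging the shards into one state dict. Below is a minimal sketch under that assumption. `load_sharded_safetensors` is a hypothetical helper, not an existing transformers API. The per-shard loader is passed in as a parameter so the merging logic stays independent of any particular tensor backend.

```python
import json
import os
from collections import defaultdict
from typing import Any, Callable, Dict, Set


def load_sharded_safetensors(
    index_path: str,
    load_file: Callable[[str], Dict[str, Any]],
) -> Dict[str, Any]:
    """Hypothetical helper: merge all shards listed in a *.safetensors.index.json file."""
    with open(index_path, "r", encoding="utf-8") as f:
        index = json.load(f)

    # "weight_map" maps each tensor name to the shard file that stores it.
    weight_map: Dict[str, str] = index["weight_map"]
    shard_dir = os.path.dirname(index_path)

    # Group tensor names by shard so that each shard file is loaded only once.
    names_by_shard: Dict[str, Set[str]] = defaultdict(set)
    for name, shard_file in weight_map.items():
        names_by_shard[shard_file].add(name)

    state_dict: Dict[str, Any] = {}
    for shard_file in sorted(names_by_shard):
        shard = load_file(os.path.join(shard_dir, shard_file))
        expected = names_by_shard[shard_file]
        # Fail loudly if the index promises a tensor the shard does not contain.
        missing = {name for name in expected if name not in shard}
        if missing:
            raise KeyError(f"{shard_file} is missing tensors: {sorted(missing)}")
        # Keep only the tensors the index assigns to this shard.
        for name in expected:
            state_dict[name] = shard[name]
    return state_dict
```

With the safetensors package installed, `safetensors.flax.load_file` can serve as the loader, e.g. `load_sharded_safetensors("path/to/model.safetensors.index.json", load_file)` after `from safetensors.flax import load_file`. The result is still a flat dict keyed by PyTorch parameter names, so it would presumably need the same key renaming and layout conversion that the existing PyTorch-to-Flax loading path already applies.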
Motivation
Recent open-source language models trained with PyTorch often release their weights only as sharded .safetensors files. Without support for loading this now-dominant format, these models are troublesome to use in JAX.
Your contribution
I'm relatively new to how saving and loading are implemented in transformers, but I'd love to try working on this feature if the core team doesn't have the bandwidth.