huggingface / transformers

🤗 Transformers: State-of-the-art Machine Learning for Pytorch, TensorFlow, and JAX.
https://huggingface.co/transformers
Apache License 2.0

Support `from_pretrained` of `FlaxPretrainedModel` from sharded `.safetensors` weights #32200

Open dest1n1s opened 1 month ago

dest1n1s commented 1 month ago

Feature request

Currently `FlaxPreTrainedModel` only supports loading pretrained models from sharded PyTorch weights or from a single-file `.safetensors` checkpoint. It would be worth adding support for loading sharded `.safetensors` checkpoints as well.

Motivation

Recent open-source language models trained with PyTorch tend to release only sharded `.safetensors` weights. The lack of support for this now-dominant format makes it troublesome to use these models in JAX.

Your contribution

I'm relatively new to the saving and loading machinery of `transformers`, but I'd love to work on this feature if the core team doesn't have enough bandwidth.

LysandreJik commented 1 month ago

Thanks for the offer @dest1n1s! It would be nice to have support for this in Flax, indeed. Please feel free to open a PR to offer that!

cc @sanchit-gandhi if you can review once the PR is open.