huggingface / transformers

🤗 Transformers: State-of-the-art Machine Learning for PyTorch, TensorFlow, and JAX.
https://huggingface.co/transformers
Apache License 2.0
128.29k stars, 25.45k forks

[QoL] Allow dtype str for torch_dtype arg of from_pretrained #31590

Closed: aliencaocao closed this pull request 1 day ago

aliencaocao commented 3 days ago

What does this PR do?

This aligns from_pretrained with the behavior of pipeline, where users can pass a string such as torch_dtype="float32" and have it recognized as torch.float32. With this change, the same string handling applies to from_pretrained on any PreTrainedModel.
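Here is a minimal sketch of the string-to-dtype conversion described above, assuming PyTorch is installed. The helper name resolve_torch_dtype is hypothetical and is not the code merged in this PR. It looks the string up as an attribute of the torch module and checks that the result is actually a torch.dtype, so a name like "nn" is rejected instead of silently returning a module. Passing "auto" through unchanged is also an assumption, mirroring the existing torch_dtype="auto" option of from_pretrained.

```python
import torch


def resolve_torch_dtype(torch_dtype):
    """Hypothetical helper: normalize a torch_dtype argument.

    Accepts None, a torch.dtype, "auto", or a string naming a dtype
    attribute of the torch module (e.g. "float32", "bfloat16", "half").
    """
    # Values that need no conversion are returned unchanged.
    if torch_dtype is None or isinstance(torch_dtype, torch.dtype):
        return torch_dtype
    if isinstance(torch_dtype, str):
        # Assumption: "auto" is left for the caller to resolve from the checkpoint.
        if torch_dtype == "auto":
            return torch_dtype
        # Look the name up on the torch module, e.g. "float32" -> torch.float32.
        candidate = getattr(torch, torch_dtype, None)
        # Only accept attributes that really are dtypes (rejects "nn", "Tensor", ...).
        if isinstance(candidate, torch.dtype):
            return candidate
    raise ValueError(f"Unsupported torch_dtype: {torch_dtype!r}")


print(resolve_torch_dtype("float32"))  # torch.float32
```

For example, resolve_torch_dtype("bfloat16") returns torch.bfloat16. Under the change described in this PR, a call like from_pretrained(..., torch_dtype="float16") would be expected to behave the same as passing torch.float16 directly.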

Before submitting

Who can review?

@amyeroberts

aliencaocao commented 2 days ago

@amyeroberts added them

HuggingFaceDocBuilderDev commented 1 day ago

The docs for this PR live here. All of your documentation changes will be reflected on that endpoint. The docs are available until 30 days after the last update.