huggingface / transformers

🤗 Transformers: State-of-the-art Machine Learning for Pytorch, TensorFlow, and JAX.
https://huggingface.co/transformers
Apache License 2.0

[model loading] framework-agnostic dtype parameter #13246

Open stas00 opened 3 years ago

stas00 commented 3 years ago

This is split off from one of the discussions at https://github.com/huggingface/transformers/pull/13209:

  1. It all started with trying to load torch models in either the desired dtype or the dtype of the pretrained model, and thus avoid needing 2x the memory, e.g. when the model only needs to be fp16. So we added torch_dtype to from_pretrained and from_config.
  2. Then we started storing torch_dtype in the config file, so that in the future the model could possibly be loaded automatically in the optimal "regime".
  3. This resulted in a discrepancy: the same symbol sometimes means a torch.dtype and at other times a string like "float32", since we can't store a torch.dtype in JSON.
  4. Then in https://github.com/huggingface/transformers/pull/13209#discussion_r693292542 we started discussing how dtype is really the same concept across pt/tf/flax, and that perhaps we should just use dtype in the config and variables, consistently as a string ("float32"), and convert it to the dtype object of the desired framework at the point of use, e.g. getattr(torch, "float32").
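A minimal sketch of the "store a string, convert at the point of use" idea from point 4. The helper names `dtype_to_str` and `dtype_from_str` and the framework keys "pt"/"tf"/"flax" are hypothetical, not existing transformers API. The PyTorch branch uses `getattr(torch, ...)` as suggested above, the TensorFlow branch uses `tf.dtypes.as_dtype`, and the Flax branch assumes `getattr(jax.numpy, ...)` (e.g. `jnp.float32`) is an acceptable dtype. Each framework is imported lazily, so the config handling itself never depends on any of them.

```python
import json
from typing import Any


def dtype_to_str(dtype: Any) -> str:
    """Normalize a dtype (string or framework dtype object) to a JSON-safe
    string such as "float16". Hypothetical helper, not transformers API."""
    if isinstance(dtype, str):
        # Accept both "float16" and "torch.float16"-style strings.
        return dtype.split(".")[-1]
    # tf.DType and numpy.dtype objects expose the bare name via `.name`.
    name = getattr(dtype, "name", None)
    if isinstance(name, str):
        return name
    # Fall back to the string form, e.g. str(torch.float16) == "torch.float16".
    return str(dtype).split(".")[-1]


def dtype_from_str(dtype_str: str, framework: str) -> Any:
    """Convert a dtype string into the requested framework's dtype object.
    Hypothetical helper; frameworks are imported only when needed."""
    if framework == "pt":
        import torch

        dtype = getattr(torch, dtype_str, None)
        # Guard against names like "Tensor" that exist on torch but are not dtypes.
        if not isinstance(dtype, torch.dtype):
            raise ValueError(f"{dtype_str!r} is not a torch dtype")
        return dtype
    if framework == "tf":
        import tensorflow as tf

        return tf.dtypes.as_dtype(dtype_str)
    if framework == "flax":
        # Assumption: jax.numpy scalar types such as jnp.float32 serve as dtypes.
        import jax.numpy as jnp

        return getattr(jnp, dtype_str)
    raise ValueError(f"unknown framework {framework!r}")


if __name__ == "__main__":
    # The config only ever holds the string form, so it round-trips through JSON.
    config = {"dtype": dtype_to_str("torch.float16")}
    print(json.dumps(config))
```

With this split, the config and the function arguments only ever see "float16", and each framework resolves it to its own dtype object at the last moment.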

A possible solution is to deprecate torch_dtype and replace it with a dtype string, both in the config and in the function arguments.
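One way the deprecation could look, sketched under assumptions: `resolve_dtype_kwarg` is a hypothetical helper (not the actual from_pretrained code) that accepts the new `dtype` argument, keeps `torch_dtype` working with a warning, rejects passing both, and returns only the string form for storing in the config. The string normalization here assumes torch-style values such as "torch.float16" or torch.float16, since that is what torch_dtype receives.

```python
import warnings
from typing import Any, Optional


def resolve_dtype_kwarg(
    dtype: Optional[Any] = None, torch_dtype: Optional[Any] = None
) -> Optional[str]:
    """Hypothetical shim: prefer `dtype`, accept deprecated `torch_dtype`,
    and return a framework-agnostic string like "float16" (or None)."""
    if torch_dtype is not None:
        if dtype is not None:
            raise ValueError("pass either `dtype` or `torch_dtype`, not both")
        warnings.warn(
            "`torch_dtype` is deprecated, use `dtype` instead",
            FutureWarning,
        )
        dtype = torch_dtype
    if dtype is None:
        return None
    # str(torch.float16) == "torch.float16", so keep only the part after the dot;
    # plain strings like "float16" pass through unchanged.
    return str(dtype).split(".")[-1]
```

FutureWarning is used rather than DeprecationWarning because Python hides DeprecationWarning by default outside `__main__`, so end users would otherwise not see it.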

Possible conflicts with the naming:

  1. We already have a dtype attribute in modeling_utils, which returns the torch.dtype of the model's first parameter.

    https://github.com/huggingface/transformers/blob/master/src/transformers/modeling_utils.py#L205

    The context is different, but still this is something to consider to avoid ambiguity.

I may have missed some other areas, so please share if something else needs to be added.

Additional notes:

@LysandreJik, @sgugger, @patrickvonplaten

LysandreJik commented 3 years ago

Would like to ping @Rocketknight1 regarding how TensorFlow handles types, and @patil-suraj regarding Flax.

Rocketknight1 commented 3 years ago

This should work in TensorFlow too: you can use tf.dtypes.as_dtype(dtype_string) to turn strings into TF dtype objects.

Joy-Lunkad commented 3 years ago

@Rocketknight1 Sorry, but could you please elaborate on how to load the model in TensorFlow, or point me in the right direction? I am new to Hugging Face and have been looking all over for instructions on how to do it. Thank you.