[Open] stas00 opened this issue 3 years ago
I'd like to ping @Rocketknight1 regarding the TensorFlow management of types, and @patil-suraj for Flax.
This should work in TensorFlow too - you can use `tf.dtypes.as_dtype(dtype_string)` to turn strings into TF dtype objects.
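For reference, a guarded sketch of the conversion mentioned above; the import is wrapped so the snippet also runs in environments where TensorFlow is not installed:

```python
# tf.dtypes.as_dtype turns a string like "float32" into a TF DType object.
# Guarded import: this is only a sketch, not part of the transformers codebase.
try:
    import tensorflow as tf

    dt = tf.dtypes.as_dtype("float32")
    assert dt == tf.float32
    assert dt.name == "float32"  # the DType round-trips back to the string form
except ImportError:
    pass  # TensorFlow not installed; nothing to demonstrate
```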
@Rocketknight1 Sorry, but can you please elaborate on how to load the model in TensorFlow or point me in the right direction? I am new to Hugging Face and I have been looking all over for instructions on how to do it. Thank you.
This is a split off from one of the discussions at https://github.com/huggingface/transformers/pull/13209:

1. We now pass `torch_dtype` to `from_pretrained` and `from_config`.
2. We now save `torch_dtype` in the config file for possible future automatic loading of the model in the optimal "regime".
3. The value is at times a `torch.dtype`, at other times a string like `"float32"`, as we can't store a `torch.dtype` in JSON.
4. `dtype` is really the same across pt/tf/flax, and perhaps we should just use `dtype` in the config and variables, consistently as a string (`"float32"`), converting it to the right dtype object of the desired framework at the point of use, e.g. `getattr(torch, "float32")`.
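A minimal sketch of what the last point proposes (the function name here is illustrative, not an actual transformers API): store a plain string in the JSON config and resolve it to a framework dtype object only at the point of use.

```python
import json


def resolve_dtype(dtype_str, framework="pt"):
    """Hypothetical helper: resolve a config dtype string to a framework
    dtype object, importing the framework lazily."""
    if framework == "pt":
        import torch
        return getattr(torch, dtype_str)      # e.g. torch.float32
    if framework == "tf":
        import tensorflow as tf
        return tf.dtypes.as_dtype(dtype_str)  # e.g. tf.float32
    if framework == "flax":
        import jax.numpy as jnp
        return getattr(jnp, dtype_str)        # e.g. jnp.float32
    raise ValueError(f"unknown framework: {framework!r}")


# The config stores only the string, since torch.dtype is not JSON-serializable:
config = {"dtype": "float32"}
restored = json.loads(json.dumps(config))
assert restored["dtype"] == "float32"
```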
A possible solution is to deprecate `torch_dtype` and replace it with a `dtype` string, both in the config and in the function argument.

Possible conflicts with the naming:
- we already have a `dtype` attribute in modeling_utils, which returns a `torch.dtype` based on the first param's dtype: https://github.com/huggingface/transformers/blob/master/src/transformers/modeling_utils.py#L205. The context is different, but this is still something to consider to avoid ambiguity.
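The deprecation path could look roughly like this - a sketch with a hypothetical function name, not the actual transformers code:

```python
import warnings


def from_pretrained_sketch(dtype=None, torch_dtype=None):
    """Hypothetical shim: accept the old `torch_dtype` argument, emit a
    deprecation warning, and fold it into the framework-agnostic `dtype`."""
    if torch_dtype is not None:
        warnings.warn(
            "`torch_dtype` is deprecated; use `dtype` instead",
            FutureWarning,
        )
        if dtype is None:
            dtype = torch_dtype
    return dtype
```

With this, old call sites like `from_pretrained_sketch(torch_dtype="float16")` keep working while warning, and new code passes `dtype="float16"` directly.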
I may have missed some other areas, so please share if something else needs to be added.
Additional notes:
@LysandreJik, @sgugger, @patrickvonplaten