instadeepai / nucleotide-transformer

🧬 Nucleotide Transformer: Building and Evaluating Robust Foundation Models for Human Genomics
https://www.biorxiv.org/content/10.1101/2023.01.11.523679v2

v2 checkpoints for TensorFlow #44

Closed felbecker closed 9 months ago

felbecker commented 9 months ago

Hi,

thank you for your incredible work!

How much work would it be to generate checkpoints for your updated v2 models that can be loaded and fine-tuned in TensorFlow? And do you, by chance, have plans to support TensorFlow in the future?

Best, Felix

dallatt commented 9 months ago

Hello @felbecker ,

HuggingFace automatically converts the models to TensorFlow safetensors. We just updated the repositories with them; you can find the TensorFlow version of one of the v2 models here.

Best regards,

Hugo

felbecker commented 9 months ago

Great, thanks!

LarsGab commented 9 months ago

Hi,

I encountered an issue while loading the TensorFlow weights for the v2 model. I tried to load them in two ways:

  1. Using TFAutoModelForMaskedLM (safe_tensors and trust_remote_code enabled), I got the following error:

    ValueError: Unrecognized configuration class <class 'transformers_modules.InstaDeepAI.nucleotide-transformer-v2-50m-multi-species.af10a726a544d702d43e1de84548c7f1bc30cc8f.esm_config.EsmConfig'> for this kind of AutoModel: TFAutoModelForMaskedLM.
  2. Using TFEsmForMaskedLM, I got:

Traceback (most recent call last):
  File "/opt/conda/lib/python3.9/site-packages/transformers/modeling_tf_pytorch_utils.py", line 348, in load_pytorch_state_dict_in_tf2_model
    array = apply_transpose(transpose, array, symbolic_weight.shape)
  File "/opt/conda/lib/python3.9/site-packages/transformers/modeling_tf_pytorch_utils.py", line 143, in apply_transpose
    weight = reshape(weight, match_shape)
  File "/opt/conda/lib/python3.9/site-packages/transformers/utils/generic.py", line 620, in reshape
    return tf.reshape(array, newshape)
  File "/opt/conda/lib/python3.9/site-packages/tensorflow/python/util/traceback_utils.py", line 153, in error_handler
    raise e.with_traceback(filtered_tb) from None
  File "/opt/conda/lib/python3.9/site-packages/tensorflow/python/eager/execute.py", line 54, in quick_execute
    tensors = pywrap_tfe.TFE_Py_Execute(ctx._handle, device_name, op_name,
tensorflow.python.framework.errors_impl.InvalidArgumentError: {{function_node __wrapped__Reshape_device_/job:localhost/replica:0/task:0/device:GPU:0}} Input to reshape is a tensor with 2097152 values, but the requested shape has 1048576 [Op:Reshape]

Is there a different way to load the model? Thanks for your help! Best, Lars
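
A note on reading the reshape error in the traceback above: `apply_transpose` receives the checkpoint array and the shape of the TensorFlow variable it should fill, so the "tensor with 2097152 values" is the stored weight and the "requested shape" with 1048576 values is what the TensorFlow layer expects. The sketch below only does the arithmetic on those two numbers. The helper `candidate_shapes` and its list of layer widths are hypothetical, chosen for illustration, and are not taken from the model's configuration.

```python
# Element counts copied from the InvalidArgumentError in the traceback.
checkpoint_values = 2_097_152  # values in the weight read from the checkpoint
requested_values = 1_048_576   # values in the shape the TF variable expects

# How many times does the expected size fit into the stored weight?
ratio, remainder = divmod(checkpoint_values, requested_values)
print(ratio, remainder)


def candidate_shapes(n_values, dims=(256, 512, 768, 1024, 2048, 4096)):
    """Return (rows, cols) pairs from `dims` whose product equals n_values.

    `dims` is an assumed list of common layer widths, not the real config.
    """
    return [(r, c) for r in dims for c in dims if r * c == n_values and r <= c]


print(candidate_shapes(requested_values))
print(candidate_shapes(checkpoint_values))
```

The stored weight is exactly twice the expected size. That would be consistent with one layer being twice as wide in the v2 checkpoint as in the stock TFEsm implementation, though the numbers alone do not prove which layer or why. Comparing the shapes in the checkpoint against the values in the model's config would be the next step.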