instadeepai / nucleotide-transformer

🧬 Nucleotide Transformer: Building and Evaluating Robust Foundation Models for Human Genomics
https://www.biorxiv.org/content/10.1101/2023.01.11.523679v2

A little bit different embeddings between jax and pytorch models #77

Closed nzhang89 closed 3 months ago

nzhang89 commented 3 months ago

Thank you for your excellent work on releasing the Nucleotide Transformer model. I am trying to use it for a downstream regression task. Since my downstream model is written in PyTorch, I am considering the Hugging Face models.

However, I found that the embeddings from the JAX model and the Hugging Face model differ slightly.

Here is an example.

The jax model

import haiku as hk
import jax
import jax.numpy as jnp
from nucleotide_transformer.pretrained import get_pretrained_model

parameters, forward_fn, tokenizer, config = get_pretrained_model(
    model_name="500M_human_ref",
    embeddings_layers_to_save=(20,),
    max_positions=32
)
forward_fn = hk.transform(forward_fn)

sequences = ["ATTCCGATTCCGATTCCG", "ATTTCTCTCTCTCTCTGAGATCGATCGATCGAT"]
# Tokenize once; each entry is a (token strings, token ids) pair.
tokenized = tokenizer.batch_tokenize(sequences)
tokens_ids = [b[1] for b in tokenized]
tokens_str = [b[0] for b in tokenized]
tokens = jnp.asarray(tokens_ids, dtype=jnp.int32)

random_key = jax.random.PRNGKey(0)
outs = forward_fn.apply(parameters, random_key, tokens)

The embedding for the first sequence (first 4 tokens)

>>> outs['embeddings_20'][0][0:4]
Array([[  0.5897,   7.0619,   8.0574, ..., -12.8165,  -8.5919,   3.5711],
       [ -3.8133,  -3.4373,  -3.432 , ...,   8.1626,  -3.5501,  -1.0834],
       [ -4.4258,  -3.113 ,  -2.7153, ...,   8.8512,  -3.9107,  -1.6326],
       [ -1.6652,  -4.6087,  -2.1508, ...,   8.9233,  -2.1052,  -0.5005]],      dtype=float32)

The pytorch model

from transformers import AutoTokenizer, AutoModelForMaskedLM
import torch

tokenizer = AutoTokenizer.from_pretrained("InstaDeepAI/nucleotide-transformer-500m-human-ref")
model = AutoModelForMaskedLM.from_pretrained("InstaDeepAI/nucleotide-transformer-500m-human-ref")
model.eval()

max_length = 32

sequences = ["ATTCCGATTCCGATTCCG", "ATTTCTCTCTCTCTCTGAGATCGATCGATCGAT"]
tokens_ids = tokenizer.batch_encode_plus(
    sequences,
    return_tensors="pt",
    padding="max_length",
    max_length=max_length,
)["input_ids"]

attention_mask = tokens_ids != tokenizer.pad_token_id
with torch.no_grad():
    torch_outs = model(
        tokens_ids,
        attention_mask=attention_mask,
        encoder_attention_mask=attention_mask,
        output_hidden_states=True
    )

The embedding for the first sequence (first 4 tokens)

>>> torch_outs['hidden_states'][20][0][0:4]
tensor([[  0.5936,   7.0626,   8.0543,  ..., -12.8166,  -8.5917,   3.5761],
        [ -3.8114,  -3.4322,  -3.4315,  ...,   8.1701,  -3.5505,  -1.0878],
        [ -4.4262,  -3.1080,  -2.7096,  ...,   8.8625,  -3.9117,  -1.6279],
        [ -1.6622,  -4.6036,  -2.1546,  ...,   8.9295,  -2.1016,  -0.5010]])

The embeddings are a little bit different. Is this expected? Do you know the possible reason that might cause the difference? Thank you so much for your help!
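To put a number on "a little bit different", here is a minimal sketch that compares only the six first-token values visible in the two printouts above. The tolerance `ATOL` is an assumption chosen for illustration, not a threshold recommended by the library authors, and the check covers just the displayed columns rather than the full embedding.

```python
# First-token values copied from the two printed outputs above
# (only the six columns that were displayed).
jax_row = [0.5897, 7.0619, 8.0574, -12.8165, -8.5919, 3.5711]
torch_row = [0.5936, 7.0626, 8.0543, -12.8166, -8.5917, 3.5761]

# Element-wise absolute difference and its maximum.
abs_diffs = [abs(a - b) for a, b in zip(jax_row, torch_row)]
max_abs_diff = max(abs_diffs)

# Relative difference, guarded against division by zero.
rel_diffs = [d / max(abs(a), 1e-12) for d, a in zip(abs_diffs, jax_row)]
max_rel_diff = max(rel_diffs)

# Assumed tolerance, for illustration only.
ATOL = 1e-2
within_tolerance = max_abs_diff <= ATOL

print(f"max abs diff: {max_abs_diff:.4f}")
print(f"max rel diff: {max_rel_diff:.4f}")
print(f"within atol={ATOL}: {within_tolerance}")
```

The same comparison could be run on the full arrays by converting both outputs to NumPy and taking the maximum absolute difference, which gives a fairer picture than the handful of printed columns.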

dallatt commented 3 months ago

Hello @nzhang89,

The layer indexing in the GitHub repository is 1-based, whereas the indexing of the hidden_states array in Hugging Face is 0-based. This means that to get the same embedding as the GitHub code, you need to extract torch_outs['hidden_states'][19]. That is why the embeddings are slightly different: they are extracted one layer apart!

Best regards, Hugo

nzhang89 commented 3 months ago

Hi @dallatt,

Thank you for your quick response. I tried with torch_outs['hidden_states'][19] but got completely different results.

>>> torch_outs['hidden_states'][20][0][0:4]
tensor([[  0.5936,   7.0626,   8.0543,  ..., -12.8166,  -8.5917,   3.5761],
        [ -3.8114,  -3.4322,  -3.4315,  ...,   8.1701,  -3.5505,  -1.0878],
        [ -4.4262,  -3.1080,  -2.7096,  ...,   8.8625,  -3.9117,  -1.6279],
        [ -1.6622,  -4.6036,  -2.1546,  ...,   8.9295,  -2.1016,  -0.5010]])

>>> torch_outs['hidden_states'][19][0][0:4]
tensor([[  2.0145,   6.4966,   7.4038,  ..., -14.8354,  -7.2971,   1.8059],
        [ -3.0252,   0.5187,  -0.5540,  ...,   1.1396,  -3.8698,   2.3325],
        [ -3.7960,   0.3949,  -0.9983,  ...,   2.2104,  -4.1310,   1.4081],
        [ -1.7485,  -1.2718,   0.1990,  ...,   2.0984,  -2.3466,   2.3011]])

torch_outs['hidden_states'] has a length of 25. I guess the first entry is the output of the embedding layer, which would be consistent with the original ESM model setup.
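The layout guessed above can be illustrated with a toy sketch: if the tuple holds the embedding output first and then one entry per layer, a 24-layer model yields 25 entries, and index k holds the output after k layers. The helper `collect_hidden_states` and the toy "layers" are hypothetical stand-ins, not the model's code. That the JAX `embeddings_20` output lines up with `hidden_states[20]` is an inference from the numbers in this thread, not something the maintainers have confirmed.

```python
def collect_hidden_states(x, layers):
    """Mimic the assumed layout: embedding output first, then one entry per layer."""
    states = [x]
    for layer in layers:
        x = layer(x)
        states.append(x)
    return tuple(states)


# 24 toy "layers", each adding 1, standing in for transformer blocks.
# The count 24 is inferred from the reported length of 25.
num_layers = 24
layers = [lambda v: v + 1 for _ in range(num_layers)]

# The starting value 0 stands in for the embedding-layer output.
hidden_states = collect_hidden_states(0, layers)

print(len(hidden_states))   # number of entries in the tuple
print(hidden_states[0])     # embedding output, before any layer
print(hidden_states[20])    # output after 20 layers
```

Under this layout, index 19 would be the output after 19 layers, which fits the observation that `hidden_states[19]` looks very different from the JAX embedding while `hidden_states[20]` nearly matches it.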